In [1]:
from torchvision import utils
from custom import *
from dataloader import *
from utils import *
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import time
import math
from tqdm import tqdm
import gc
import os

In [2]:
def init_weights(m):
    """Xavier-initialise the weights and bias of convolutional layers.

    Meant to be passed to ``model.apply``. Only ``nn.Conv2d`` and
    ``nn.ConvTranspose2d`` modules are modified; every other module is left
    untouched.

    Args:
        m: the module currently visited by ``apply``.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return

    torch.nn.init.xavier_uniform_(m.weight.data)

    # Layers created with bias=False expose ``bias`` as None; skip them
    # instead of failing on ``None.data``.
    if m.bias is not None:
        # xavier_uniform_ requires at least two dimensions, so the 1-D bias is
        # initialised through an (n, 1) view that shares its storage.
        # (An earlier experiment used a hand-computed uniform bound instead.)
        torch.nn.init.xavier_uniform_(m.bias.data.view(m.bias.data.shape[0], 1))
        
        


In [3]:

    
def _reserve_log_paths(stem='logfile'):
    """Return ``(log_path, summary_path)`` for the first unused log index.

    The first run uses ``<stem>.txt``; later runs use ``<stem><i>.txt`` with the
    smallest ``i >= 1`` that does not exist yet. The summary name is built from
    the same stem and index so two runs never share a summary file.
    """
    tag = ''
    if os.path.exists(stem + '.txt'):
        i = 1
        while os.path.exists(stem + str(i) + '.txt'):
            i += 1
        tag = str(i)
    return stem + tag + '.txt', stem + tag + '_summary.txt'


def train(model, criterion, epochs, train_loader, val_loader, test_loader, use_gpu, name):
    """Train ``model`` with Adam, validating and logging after every epoch.

    Args:
        model: network to optimise. When ``use_gpu`` is true it is wrapped in
            ``torch.nn.DataParallel`` and moved to ``cuda:0``.
        criterion: loss called as ``criterion(outputs, labels.long())``.
        epochs: maximum number of epochs.
        train_loader: iterable of ``(inputs, target, labels)`` batches; the
            middle element is ignored.
        val_loader: loader handed to ``val`` after each epoch.
        test_loader: accepted for interface compatibility; not used here.
        use_gpu: whether to run on ``cuda:0``.
        name: path passed to ``torch.save`` whenever the validation loss
            reaches a new minimum.

    Returns:
        Tuple ``(val_loss_set, val_acc_set, val_iou_set)`` with one entry per
        completed epoch.

    Side effects:
        Appends a per-epoch log and a final summary to text files in the
        current directory and writes the best model to ``name``.
    """
    logname, logname_summary = _reserve_log_paths()

    print('Loading results to logfile: ' + logname)
    with open(logname, "a") as file:
        file.write("Lofile DATA: Validation Loss and Accuracy\n")
    print('Loading Summary to : ' + logname_summary)

    optimizer = optim.Adam(model.parameters(), lr=5e-3)
    device = None
    if use_gpu:
        device = torch.device("cuda:0")
        model = torch.nn.DataParallel(model)
        model.to(device)

    val_loss_set = []
    val_acc_set = []
    val_iou_set = []
    training_loss = []

    # Early-stop criteria: never stop before `earliestStopEpoch`, and only once
    # the best validation loss is more than `earlyStopDelta` epochs old.
    minLoss = 1e6
    minLossIdx = 0
    earliestStopEpoch = 10
    earlyStopDelta = 5

    for epoch in range(epochs):
        ts = time.time()

        # val() switches the model to eval mode at the end of every epoch, so
        # training mode has to be restored before the next optimisation pass.
        model.train()

        # Loss of the most recent batch as a plain float (NaN if the loader
        # yields nothing), so no tensor is kept alive between epochs.
        last_loss = float('nan')

        for step, (inputs, _, labels) in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()

            if use_gpu:
                inputs = inputs.to(device)
                labels = labels.to(device)

            loss = criterion(model(inputs), labels.long())
            loss.backward()
            optimizer.step()

            last_loss = loss.item()
            del loss

            if step % 10 == 0:
                print("epoch{}, iter{}, loss: {}".format(epoch, step, last_loss))

        # Validation metrics for this epoch
        val_loss, val_acc, val_iou = val(model, val_loader, criterion, use_gpu)
        val_loss_set.append(val_loss)
        val_acc_set.append(val_acc)
        val_iou_set.append(val_iou)

        # NOTE: the reported "train loss" is the loss of the last batch of the
        # epoch, not an epoch average.
        print("epoch {}, time {}, train loss {}, val loss {}, val acc {}, val iou {}".format(
            epoch, time.time() - ts, last_loss, val_loss, val_acc, val_iou))
        training_loss.append(last_loss)

        with open(logname, "a") as file:
            file.write("writing!\n")
            file.write("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))
            file.write("\n training Loss:   " + str(last_loss))
            file.write("\n Validation Loss: " + str(val_loss_set[-1]))
            file.write("\n Validation acc:  " + str(val_acc_set[-1]))
            file.write("\n Validation iou:  " + str(val_iou_set[-1]) + "\n ")

        if val_loss < minLoss:
            # New best: persist the whole model object (the DataParallel
            # wrapper when running on GPU), not just a state_dict.
            torch.save(model, name)
            minLoss = val_loss
            minLossIdx = epoch
        elif epoch > earliestStopEpoch and (epoch - minLossIdx) > earlyStopDelta:
            # Past the minimum epoch count and no improvement for delta epochs
            print("Stopping early at {}".format(minLossIdx))
            break

    with open(logname_summary, "a") as file:
        file.write("Summary!\n")
        file.write("Stopped early at {}".format(minLossIdx))
        file.write("\n training Loss:   " + str(training_loss))
        file.write("\n Validation Loss: " + str(val_loss_set))
        file.write("\n Validation acc:  " + str(val_acc_set))
        file.write("\n Validation iou:  " + str(val_iou_set) + "\n ")

    return val_loss_set, val_acc_set, val_iou_set


def val(model, val_loader, criterion, use_gpu):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: network to evaluate.
        val_loader: iterable of ``(inputs, target, labels)`` batches; the
            middle element is ignored.
        criterion: loss called as ``criterion(outputs, labels.long())``.
        use_gpu: when true, batches are moved to ``cuda:0`` (the model is
            expected to be there already).

    Returns:
        Tuple ``(avg_loss, acc, IOU)``: the mean per-batch loss, the mean of
        ``pixel_acc`` over batches, and the mean of ``iou`` over batches as an
        array of shape ``(1, 19)``.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()

    softmax = nn.Softmax(dim=1)
    device = torch.device("cuda:0") if use_gpu else None

    losses = []
    accs = []
    # NOTE(review): 19 is hard-coded; presumably the number of classes that
    # `iou` reports per batch - confirm against utils.iou.
    iou_sum = np.zeros((1, 19))

    for X, _, Y in tqdm(val_loader):
        if use_gpu:
            inputs = X.to(device)
            labels = Y.to(device)
        else:
            inputs, labels = X, Y

        with torch.no_grad():
            outputs = model(inputs)
            losses.append(criterion(outputs, labels.long()).item())
            prediction = softmax(outputs)
            accs.append(pixel_acc(prediction, labels))
            iou_sum = iou_sum + np.array(iou(prediction, labels))

    n_batches = len(losses)
    if n_batches == 0:
        raise ValueError("val_loader produced no batches")

    acc = sum(accs) / n_batches
    avg_loss = sum(losses) / n_batches
    # Average over the number of batches (the last enumerate index would be
    # one too small and zero for a single batch).
    IOU = iou_sum / n_batches

    return avg_loss, acc, IOU
       
    
    
    
def test(model, use_gpu, loader=None):
    """Compute mean pixel accuracy and IoU of ``model`` on the test set.

    The model is switched to eval mode (and left there) and the forward passes
    run under ``torch.no_grad()``.

    Args:
        model: network to evaluate; moved to ``cuda:0`` when ``use_gpu`` is true.
        use_gpu: whether to run on ``cuda:0``.
        loader: iterable of ``(inputs, target, labels)`` batches. Defaults to
            the module-level ``test_loader`` so existing two-argument calls
            keep working.

    Returns:
        Tuple ``(acc, IOU)``: the mean of ``pixel_acc`` over batches and the
        mean of ``iou`` over batches as a ``(1, C)`` array, where ``C`` is
        ``target.shape[1]`` of the first batch.

    Raises:
        ValueError: if the loader yields no batches.
    """
    if loader is None:
        # Falls back to the global loader created by the notebook's main cell
        loader = test_loader

    softmax = nn.Softmax(dim=1)
    device = None
    if use_gpu:
        device = torch.device("cuda:0")
        model.to(device)

    # Use evaluation behaviour of the model's layers for the measurement
    model.eval()

    accs = []
    iou_sum = None

    for X, tar, Y in loader:
        if iou_sum is None:
            # NOTE(review): sized from the target tensor here, while val()
            # hard-codes 19 - confirm which matches what `iou` returns.
            iou_sum = np.zeros((1, tar.shape[1]))

        if use_gpu:
            inputs = X.to(device)
            labels = Y.to(device)
        else:
            inputs, labels = X, Y

        with torch.no_grad():
            prediction = softmax(model(inputs))
            accs.append(pixel_acc(prediction, labels))
            # Compare against the labels on the same device as the prediction
            iou_sum = iou_sum + np.array(iou(prediction, labels))

    n_batches = len(accs)
    if n_batches == 0:
        raise ValueError("test loader produced no batches")

    acc = sum(accs) / n_batches
    # Average over the number of batches, not the last batch index
    IOU = iou_sum / n_batches

    return acc, IOU
    

In [None]:
def checkM():
    """Print the type and size of every tensor known to the garbage collector.

    Debugging aid for memory usage: walks ``gc.get_objects()`` and reports
    objects that are tensors or that carry a tensor in a ``data`` attribute.
    """
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                print(type(obj), obj.size())
        except Exception:
            # Probing arbitrary objects can raise; skip those, but let
            # KeyboardInterrupt / SystemExit through so the scan can be aborted.
            continue

if __name__ == "__main__":
    # Datasets and loaders. `test_loader` stays a module-level name because
    # test() falls back to it.
    train_dataset = CityScapesDataset(csv_file='train.csv')
    val_dataset = CityScapesDataset(csv_file='val.csv')
    test_dataset = CityScapesDataset(csv_file='test.csv')
    train_loader = DataLoader(dataset=train_dataset,
                          batch_size=2,
                          num_workers=8,
                          shuffle=True)
    # Evaluation only averages over batches, so shuffling buys nothing here
    val_loader = DataLoader(dataset=val_dataset,
                          batch_size=2,
                          num_workers=8,
                          shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                          batch_size=2,
                          num_workers=8,
                          shuffle=False)

    epochs = 100
    criterion = torch.nn.CrossEntropyLoss()
    # TODO: 34 is a magic number - presumably the number of label ids in the
    # dataset; confirm and move to a named constant.
    model = Custom(n_class=34)
    model.apply(init_weights)

    use_gpu = torch.cuda.is_available()

    # train() saves the best model (the whole object) under this name
    checkpoint_name = "Custom"
    train(model, criterion, epochs, train_loader, val_loader, test_loader, use_gpu, checkpoint_name)

    # Restore the best weights from the checkpoint written by train(). The
    # previous path './save_param' is not produced anywhere in this notebook.
    # NOTE(review): this unpickles a full model object; recent PyTorch versions
    # may need `weights_only=False` here. Only load files you created yourself.
    best_model = torch.load(checkpoint_name)
    if isinstance(best_model, torch.nn.DataParallel):
        # On GPU train() wraps the model, so unwrap to get matching key names
        best_model = best_model.module
    model.load_state_dict(best_model.state_dict())
    

Loading results to logfile: logfile6.txt
Loading Summary to : logfile6_summary.txt


2it [00:08,  5.76s/it]

epoch0, iter0, loss: 3.6874866485595703


11it [00:16,  1.19s/it]

epoch0, iter10, loss: 1.993118166923523


22it [00:24,  1.35it/s]

epoch0, iter20, loss: 1.8649331331253052


31it [00:31,  1.00it/s]

epoch0, iter30, loss: 1.9954395294189453


41it [00:39,  1.00s/it]

epoch0, iter40, loss: 1.6205285787582397


51it [00:47,  1.00s/it]

epoch0, iter50, loss: 1.6849114894866943


61it [00:55,  1.01s/it]

epoch0, iter60, loss: 2.595975637435913


71it [01:03,  1.01s/it]

epoch0, iter70, loss: 1.6406168937683105


81it [01:11,  1.01s/it]

epoch0, iter80, loss: 1.6844669580459595


91it [01:19,  1.00s/it]

epoch0, iter90, loss: 1.7309153079986572


102it [01:27,  1.36it/s]

epoch0, iter100, loss: 1.552109956741333


111it [01:35,  1.00s/it]

epoch0, iter110, loss: 1.9126060009002686


121it [01:42,  1.01s/it]

epoch0, iter120, loss: 1.9794156551361084


131it [01:50,  1.01s/it]

epoch0, iter130, loss: 1.9525591135025024


141it [01:58,  1.01s/it]

epoch0, iter140, loss: 1.9359159469604492


151it [02:06,  1.01s/it]

epoch0, iter150, loss: 1.6958041191101074


161it [02:14,  1.01s/it]

epoch0, iter160, loss: 1.8270578384399414


171it [02:22,  1.01s/it]

epoch0, iter170, loss: 1.509909987449646


181it [02:30,  1.01s/it]

epoch0, iter180, loss: 1.74209463596344


191it [02:38,  1.01s/it]

epoch0, iter190, loss: 2.3612515926361084


201it [02:46,  1.01s/it]

epoch0, iter200, loss: 1.8669507503509521


212it [02:54,  1.35it/s]

epoch0, iter210, loss: 1.3149124383926392


221it [03:02,  1.01s/it]

epoch0, iter220, loss: 2.234273672103882


231it [03:10,  1.01s/it]

epoch0, iter230, loss: 1.3949358463287354


241it [03:18,  1.01s/it]

epoch0, iter240, loss: 1.4348853826522827


251it [03:26,  1.01s/it]

epoch0, iter250, loss: 2.076296091079712


261it [03:34,  1.01s/it]

epoch0, iter260, loss: 1.560309886932373


271it [03:41,  1.01s/it]

epoch0, iter270, loss: 1.882183313369751


281it [03:49,  1.01s/it]

epoch0, iter280, loss: 1.7053391933441162


291it [03:57,  1.00s/it]

epoch0, iter290, loss: 1.5853402614593506


301it [04:05,  1.01s/it]

epoch0, iter300, loss: 1.1527342796325684


311it [04:13,  1.01s/it]

epoch0, iter310, loss: 1.6357016563415527


322it [04:21,  1.36it/s]

epoch0, iter320, loss: 1.8910924196243286


331it [04:29,  1.00s/it]

epoch0, iter330, loss: 1.5187195539474487


341it [04:37,  1.01s/it]

epoch0, iter340, loss: 1.5076831579208374


351it [04:45,  1.01s/it]

epoch0, iter350, loss: 1.5353976488113403


361it [04:53,  1.01s/it]

epoch0, iter360, loss: 1.406674861907959


371it [05:01,  1.00s/it]

epoch0, iter370, loss: 1.5654430389404297


381it [05:09,  1.00s/it]

epoch0, iter380, loss: 1.4617027044296265


391it [05:17,  1.01s/it]

epoch0, iter390, loss: 1.4081463813781738


401it [05:24,  1.01s/it]

epoch0, iter400, loss: 1.5545734167099


411it [05:32,  1.01s/it]

epoch0, iter410, loss: 1.5250369310379028


421it [05:40,  1.01s/it]

epoch0, iter420, loss: 1.8036680221557617


431it [05:48,  1.01s/it]

epoch0, iter430, loss: 1.4008185863494873


441it [05:56,  1.01s/it]

epoch0, iter440, loss: 1.6135172843933105


451it [06:04,  1.01s/it]

epoch0, iter450, loss: 1.3355287313461304


461it [06:12,  1.01s/it]

epoch0, iter460, loss: 1.883622169494629


471it [06:20,  1.01s/it]

epoch0, iter470, loss: 3.248440980911255


481it [06:28,  1.01s/it]

epoch0, iter480, loss: 1.0854566097259521


491it [06:36,  1.01s/it]

epoch0, iter490, loss: 1.3718429803848267


501it [06:44,  1.01s/it]

epoch0, iter500, loss: 1.318204641342163


511it [06:52,  1.01s/it]

epoch0, iter510, loss: 1.2746317386627197


521it [07:00,  1.01s/it]

epoch0, iter520, loss: 1.1491320133209229


531it [07:08,  1.01s/it]

epoch0, iter530, loss: 1.8958449363708496


541it [07:15,  1.01s/it]

epoch0, iter540, loss: 1.7857221364974976


552it [07:24,  1.36it/s]

epoch0, iter550, loss: 1.5925185680389404


561it [07:31,  1.00s/it]

epoch0, iter560, loss: 1.4385963678359985


571it [07:39,  1.01s/it]

epoch0, iter570, loss: 1.4021074771881104


581it [07:47,  1.01s/it]

epoch0, iter580, loss: 1.5336085557937622


591it [07:55,  1.01s/it]

epoch0, iter590, loss: 1.3051955699920654


601it [08:03,  1.01s/it]

epoch0, iter600, loss: 2.598149538040161


611it [08:11,  1.01s/it]

epoch0, iter610, loss: 1.6250596046447754


621it [08:19,  1.01s/it]

epoch0, iter620, loss: 1.4818041324615479


631it [08:27,  1.01s/it]

epoch0, iter630, loss: 1.6522588729858398


641it [08:35,  1.01s/it]

epoch0, iter640, loss: 1.568432331085205


651it [08:43,  1.01s/it]

epoch0, iter650, loss: 1.553775429725647


661it [08:51,  1.01s/it]

epoch0, iter660, loss: 1.9504880905151367


671it [08:59,  1.01s/it]

epoch0, iter670, loss: 2.2421982288360596


681it [09:07,  1.01s/it]

epoch0, iter680, loss: 1.2865948677062988


691it [09:15,  1.01s/it]

epoch0, iter690, loss: 1.7982301712036133


701it [09:23,  1.01s/it]

epoch0, iter700, loss: 1.2802987098693848


711it [09:30,  1.01s/it]

epoch0, iter710, loss: 1.3186581134796143


721it [09:38,  1.01s/it]

epoch0, iter720, loss: 1.7242095470428467


731it [09:46,  1.01s/it]

epoch0, iter730, loss: 1.3863152265548706


741it [09:54,  1.01s/it]

epoch0, iter740, loss: 1.330289363861084


751it [10:02,  1.01s/it]

epoch0, iter750, loss: 1.985758900642395


761it [10:10,  1.01s/it]

epoch0, iter760, loss: 1.432651162147522


772it [10:18,  1.35it/s]

epoch0, iter770, loss: 1.645565390586853


781it [10:26,  1.00s/it]

epoch0, iter780, loss: 2.4437646865844727


791it [10:34,  1.01s/it]

epoch0, iter790, loss: 2.2000205516815186


801it [10:42,  1.01s/it]

epoch0, iter800, loss: 1.079526662826538


811it [10:50,  1.01s/it]

epoch0, iter810, loss: 1.183671236038208


821it [10:58,  1.01s/it]

epoch0, iter820, loss: 1.3593817949295044


831it [11:06,  1.01s/it]

epoch0, iter830, loss: 1.2402440309524536


841it [11:14,  1.01s/it]

epoch0, iter840, loss: 1.1999868154525757


851it [11:22,  1.01s/it]

epoch0, iter850, loss: 1.832255482673645


861it [11:29,  1.01s/it]

epoch0, iter860, loss: 1.4642510414123535


871it [11:37,  1.01s/it]

epoch0, iter870, loss: 1.2308200597763062


881it [11:45,  1.01s/it]

epoch0, iter880, loss: 1.2463765144348145


891it [11:53,  1.01s/it]

epoch0, iter890, loss: 1.378576636314392


901it [12:01,  1.01s/it]

epoch0, iter900, loss: 1.3146802186965942


911it [12:09,  1.01s/it]

epoch0, iter910, loss: 1.721304178237915


921it [12:17,  1.01s/it]

epoch0, iter920, loss: 1.7041531801223755


931it [12:25,  1.01s/it]

epoch0, iter930, loss: 1.5169447660446167


941it [12:33,  1.01s/it]

epoch0, iter940, loss: 1.5990018844604492


951it [12:41,  1.01s/it]

epoch0, iter950, loss: 1.2297601699829102


961it [12:49,  1.01s/it]

epoch0, iter960, loss: 1.1201708316802979


971it [12:57,  1.01s/it]

epoch0, iter970, loss: 1.2224146127700806


981it [13:05,  1.01s/it]

epoch0, iter980, loss: 1.148391604423523


991it [13:13,  1.01s/it]

epoch0, iter990, loss: 1.443306803703308


1001it [13:20,  1.01s/it]

epoch0, iter1000, loss: 1.2253929376602173


1011it [13:28,  1.01s/it]

epoch0, iter1010, loss: 1.1726213693618774


1021it [13:36,  1.01s/it]

epoch0, iter1020, loss: 0.9981908202171326


1031it [13:44,  1.01s/it]

epoch0, iter1030, loss: 1.7010509967803955


1041it [13:52,  1.01s/it]

epoch0, iter1040, loss: 1.4781415462493896


1051it [14:00,  1.01s/it]

epoch0, iter1050, loss: 1.2948811054229736


1061it [14:08,  1.01s/it]

epoch0, iter1060, loss: 1.5925710201263428


1072it [14:16,  1.31it/s]

epoch0, iter1070, loss: 1.1010804176330566


1081it [14:24,  1.00s/it]

epoch0, iter1080, loss: 1.263534665107727


1091it [14:32,  1.01s/it]

epoch0, iter1090, loss: 1.2224453687667847


1101it [14:40,  1.01s/it]

epoch0, iter1100, loss: 1.913055181503296


1111it [14:48,  1.01s/it]

epoch0, iter1110, loss: 1.0002293586730957


1121it [14:56,  1.01s/it]

epoch0, iter1120, loss: 1.1669576168060303


1131it [15:04,  1.01s/it]

epoch0, iter1130, loss: 1.2238112688064575


1141it [15:12,  1.01s/it]

epoch0, iter1140, loss: 1.090611219406128


1151it [15:20,  1.01s/it]

epoch0, iter1150, loss: 2.012876272201538


1161it [15:27,  1.01s/it]

epoch0, iter1160, loss: 1.4918643236160278


1171it [15:35,  1.01s/it]

epoch0, iter1170, loss: 1.6038801670074463


1181it [15:43,  1.01s/it]

epoch0, iter1180, loss: 1.4809925556182861


1191it [15:51,  1.01s/it]

epoch0, iter1190, loss: 1.4864859580993652


1201it [15:59,  1.01s/it]

epoch0, iter1200, loss: 1.1725022792816162


1211it [16:07,  1.01s/it]

epoch0, iter1210, loss: 1.0369457006454468


1221it [16:15,  1.01s/it]

epoch0, iter1220, loss: 2.030799627304077


1232it [16:23,  1.35it/s]

epoch0, iter1230, loss: 1.5562337636947632


1241it [16:31,  1.00s/it]

epoch0, iter1240, loss: 1.307465672492981


1252it [16:39,  1.36it/s]

epoch0, iter1250, loss: 1.1879545450210571


1261it [16:47,  1.01s/it]

epoch0, iter1260, loss: 1.4099230766296387


1271it [16:55,  1.01s/it]

epoch0, iter1270, loss: 1.3429831266403198


1281it [17:03,  1.01s/it]

epoch0, iter1280, loss: 1.3568859100341797


1291it [17:11,  1.01s/it]

epoch0, iter1290, loss: 1.9024884700775146


1301it [17:19,  1.01s/it]

epoch0, iter1300, loss: 1.170203685760498


1311it [17:27,  1.01s/it]

epoch0, iter1310, loss: 1.755142092704773


1321it [17:34,  1.01s/it]

epoch0, iter1320, loss: 1.161767601966858


1331it [17:42,  1.01s/it]

epoch0, iter1330, loss: 1.2132571935653687


1341it [17:50,  1.01s/it]

epoch0, iter1340, loss: 1.3372704982757568


1351it [17:58,  1.01s/it]

epoch0, iter1350, loss: 1.0758706331253052


1361it [18:06,  1.01s/it]

epoch0, iter1360, loss: 1.6031546592712402


1371it [18:14,  1.01s/it]

epoch0, iter1370, loss: 1.841813325881958


1381it [18:22,  1.01s/it]

epoch0, iter1380, loss: 1.151539921760559


1391it [18:30,  1.01s/it]

epoch0, iter1390, loss: 1.1419142484664917


1401it [18:38,  1.01s/it]

epoch0, iter1400, loss: 1.3750450611114502


1411it [18:46,  1.01s/it]

epoch0, iter1410, loss: 0.9705425500869751


1421it [18:54,  1.01s/it]

epoch0, iter1420, loss: 1.1502299308776855


1431it [19:02,  1.01s/it]

epoch0, iter1430, loss: 1.1053636074066162


1441it [19:10,  1.01s/it]

epoch0, iter1440, loss: 1.4113233089447021


1451it [19:18,  1.01s/it]

epoch0, iter1450, loss: 0.9583432674407959


1461it [19:25,  1.01s/it]

epoch0, iter1460, loss: 1.2146198749542236


1471it [19:33,  1.01s/it]

epoch0, iter1470, loss: 1.9242018461227417


1481it [19:41,  1.01s/it]

epoch0, iter1480, loss: 2.211298942565918


1488it [19:47,  1.01it/s]
250it [01:38,  2.83it/s]


epoch 0, time 1287.831246137619, train loss 1.0826466083526611, val loss 1.8919261527061462, val acc 45.88572874069214, val iou [[5.16785754e-01 2.41206919e-02 2.66387837e-01            nan
             nan 1.50781210e-05            nan 2.28940815e-02
  2.34129118e-01            nan 5.13516133e-01 3.17048929e-02
             nan 1.25761269e-01            nan            nan
             nan            nan            nan]]


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
2it [00:03,  2.09s/it]

epoch1, iter0, loss: 1.4098838567733765


11it [00:09,  1.18it/s]

epoch1, iter10, loss: 1.4923175573349


21it [00:15,  1.27it/s]

epoch1, iter20, loss: 1.571415662765503


31it [00:21,  1.27it/s]

epoch1, iter30, loss: 1.6353065967559814


42it [00:28,  1.72it/s]

epoch1, iter40, loss: 1.6815193891525269


51it [00:34,  1.27it/s]

epoch1, iter50, loss: 1.7767894268035889


61it [00:40,  1.27it/s]

epoch1, iter60, loss: 2.078259229660034


71it [00:46,  1.26it/s]

epoch1, iter70, loss: 1.630918025970459


81it [00:52,  1.26it/s]

epoch1, iter80, loss: 1.689947247505188


91it [00:59,  1.27it/s]

epoch1, iter90, loss: 1.6568522453308105


101it [01:05,  1.27it/s]

epoch1, iter100, loss: 1.191127896308899


111it [01:11,  1.27it/s]

epoch1, iter110, loss: 1.499384880065918


121it [01:17,  1.27it/s]

epoch1, iter120, loss: 1.7950937747955322


131it [01:24,  1.27it/s]

epoch1, iter130, loss: 1.8031656742095947


141it [01:30,  1.27it/s]

epoch1, iter140, loss: 2.07659649848938


151it [01:36,  1.27it/s]

epoch1, iter150, loss: 1.6762373447418213


161it [01:42,  1.27it/s]

epoch1, iter160, loss: 1.233525037765503


171it [01:48,  1.26it/s]

epoch1, iter170, loss: 1.606899619102478


181it [01:55,  1.27it/s]

epoch1, iter180, loss: 1.510339617729187


191it [02:01,  1.26it/s]

epoch1, iter190, loss: 1.3760383129119873


199it [02:05,  1.59it/s]