   ## Laplacian Pyramid Reconstruction and Refinement for Semantic Segmentation

In [None]:
from LRR_lib import *

In [None]:
# Build the single-scale LRR32s network, initialise it from a pretrained
# VGG-16 (`models.vgg16(pretrained=True)` via `init_vgg16_params`), then wrap
# it in DataParallel over every visible CUDA device and move it to the GPU.
model = LRR32s()
vgg16 = models.vgg16(pretrained=True)
model.init_vgg16_params(vgg16)
model = torch.nn.DataParallel(
    model,
    device_ids=range(torch.cuda.device_count()),
).cuda()

In [None]:
# Hyper-parameters for the LRR32s training cell.
Epoch = 100                  # number of passes over the training loader
img_row, img_col = 224, 224  # (rows, cols) image size handed to the loaders
batch_size_voc = 6           # mini-batch size for the PASCAL loaders
Momemtum = 0.9               # SGD momentum (misspelled name is what later cells read)
Weight_Decay = 5e-4          # SGD weight decay
Learning_Rate = 2e-4         # SGD learning rate

In [None]:
# Train the single-scale LRR32s model on PASCAL VOC, validating after every
# epoch and collecting the printed metric values in `results`.
data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
results = []  # every metric value printed after each validation pass, in order
d_scale1 = 4  # factor passed to label_downscale before comparing with the model output

data_loader = get_loader('pascal')
data_path = get_data_path('pascal')

t_loader = data_loader(data_path, is_transform=True, img_size=(img_row, img_col),
                       augmentations=data_aug)
v_loader = data_loader(data_path, is_transform=True, split='val',
                       img_size=(img_row, img_col))

n_classes = t_loader.n_classes
# label_downscale() is always handed the fixed batch_size_voc (validation below
# skips partial batches for the same reason), so never emit a short final
# training batch.
trainloader = data.DataLoader(t_loader, batch_size=batch_size_voc, num_workers=8,
                              shuffle=True, drop_last=True)
valloader = data.DataLoader(v_loader, batch_size=batch_size_voc, num_workers=8)

running_metrics = runningScore(n_classes)

optimizer = torch.optim.SGD(model.parameters(), lr=Learning_Rate, momentum=Momemtum,
                            weight_decay=Weight_Decay)

loss_fn = cross_entropy2d

best_iou = -100.0  # NOTE(review): not read or updated anywhere in this cell
for epoch in range(Epoch):
    model.train()
    for i, (images, labels) in enumerate(trainloader):
        labels = label_downscale(labels, d_scale1, batch_size_voc, img_row)
        images = Variable(images.cuda())
        labels = Variable(labels.cuda())

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(input=outputs, target=labels)
        loss.backward()
        optimizer.step()

        # End the epoch once more than 8600 training samples have been used.
        if (i + 1) * batch_size_voc > 8600:
            break
        if (i + 1) % 10 == 0:
            # `.data[0]` is the pre-0.4 PyTorch scalar accessor, matching the
            # Variable/volatile API used here; on PyTorch >= 0.4 use loss.item().
            print("Epoch [%d/%d/%d] Loss: %.10f"
                  % ((i + 1) * batch_size_voc, epoch + 1, Epoch, loss.data[0]))

    model.eval()
    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
        # Skip a short final batch: label_downscale is given batch_size_voc.
        if labels_val.size()[0] != batch_size_voc:
            continue
        labels_val = label_downscale(labels_val, d_scale1, batch_size_voc, img_row)
        images_val = Variable(images_val.cuda(), volatile=True)
        labels_val = Variable(labels_val.cuda(), volatile=True)

        outputs = model(images_val)
        pred = outputs.data.max(1)[1].cpu().numpy()
        gt = labels_val.data.cpu().numpy()
        running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()
    for k, v in score.items():
        print(k, v)
        results.append(v)
    # Reset once per epoch after all scores are read; this also runs when
    # `score` is empty, so metrics never leak into the next epoch.
    running_metrics.reset()

In [None]:
# Build the full multi-output LRR network, run its 2x/4x parameter initialiser
# (`init_decov_2x_4x_params`), then wrap it in DataParallel over every visible
# CUDA device and move it to the GPU.
# NOTE(review): nothing here copies the weights trained in the LRR32s cell;
# confirm that LRR() / init_decov_2x_4x_params handles backbone initialisation.
model = LRR()
model.init_decov_2x_4x_params()
model = torch.nn.DataParallel(
    model,
    device_ids=range(torch.cuda.device_count()),
).cuda()

In [None]:
# Hyper-parameters for the staged LRR training cell.
Epoch = 150                  # number of passes over the training loader
img_row, img_col = 224, 224  # (rows, cols) image size handed to the loaders
batch_size_voc = 6           # mini-batch size for the PASCAL loaders
Momemtum = 0.9               # SGD momentum (misspelled name is what later cells read)
Weight_Decay = 5e-4          # SGD weight decay (used by the 4x and 2x optimizers)
# Per-stage SGD learning rates: optimizer4x, optimizer2x, optimizerft.
Learning_Rate4x = 1e-4
Learning_Rate2x = 1e-4
Learning_Rateft = 1e-5

In [None]:
# Staged training of the full LRR model, validating after every epoch:
#   epochs [0, last_4x_epoch]               -> 4x output only   (optimizer4x)
#   epochs (last_4x_epoch, first_ft_epoch)  -> 2x output only   (optimizer2x)
#   epochs [first_ft_epoch, Epoch)          -> all three losses (optimizerft)
data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
results = []  # every metric value printed after each validation pass, in order
scale8x = 4   # label_downscale factor for the labels paired with outputs8x
scale4x = 2   # label_downscale factor for the labels paired with outputs4x
# outputs2x is trained and validated against the un-downscaled labels.

last_4x_epoch = 50    # final epoch (inclusive) of the 4x-only stage
first_ft_epoch = 100  # first epoch of the joint stage; 2x-only runs in between

data_loader = get_loader('pascal')
data_path = get_data_path('pascal')
print(data_path)

t_loader = data_loader(data_path, is_transform=True, img_size=(img_row, img_col),
                       augmentations=data_aug)
v_loader = data_loader(data_path, is_transform=True, split='val',
                       img_size=(img_row, img_col))

n_classes = t_loader.n_classes
# label_downscale() is always handed the fixed batch_size_voc (validation below
# skips partial batches for the same reason), so never emit a short final
# training batch.
trainloader = data.DataLoader(t_loader, batch_size=batch_size_voc, num_workers=8,
                              shuffle=True, drop_last=True)
valloader = data.DataLoader(v_loader, batch_size=batch_size_voc, num_workers=8)

# Setup Metrics
running_metrics = runningScore(n_classes)

optimizerft = torch.optim.SGD(model.parameters(), lr=Learning_Rateft, momentum=Momemtum)
optimizer4x = torch.optim.SGD(model.parameters(), lr=Learning_Rate4x, momentum=Momemtum,
                              weight_decay=Weight_Decay)
optimizer2x = torch.optim.SGD(model.parameters(), lr=Learning_Rate2x, momentum=Momemtum,
                              weight_decay=Weight_Decay)

loss_fn = cross_entropy2d

best_iou = -100.0  # NOTE(review): not read or updated anywhere in this cell
for epoch in range(Epoch):
    model.train()
    for i, (images, labels) in enumerate(trainloader):
        labels8x = label_downscale(labels, scale8x, batch_size_voc, img_row)
        labels4x = label_downscale(labels, scale4x, batch_size_voc, img_row)

        images = Variable(images.cuda())
        labels2x = Variable(labels.cuda())
        labels4x = Variable(labels4x.cuda())
        labels8x = Variable(labels8x.cuda())

        outputs8x, outputs4x, outputs2x = model(images)
        seen = (i + 1) * batch_size_voc

        # Stages must be mutually exclusive. Running two of them on the same
        # forward pass calls backward() twice on one graph, and the second
        # call fails once the first has freed its buffers.
        # `.data[0]` below is the pre-0.4 PyTorch scalar accessor, matching
        # the Variable/volatile API used here; on PyTorch >= 0.4 use .item().
        if epoch <= last_4x_epoch:
            optimizer4x.zero_grad()
            loss4x = loss_fn(input=outputs4x, target=labels4x)
            loss4x.backward()
            optimizer4x.step()
            print("Epoch [%d/%d/%d] Loss4x: %.4f "
                  % (seen, epoch + 1, Epoch, loss4x.data[0]))
        elif epoch < first_ft_epoch:
            optimizer2x.zero_grad()
            loss2x = loss_fn(input=outputs2x, target=labels2x)
            loss2x.backward()
            optimizer2x.step()
            print("Epoch [%d/%d/%d] Loss2x: %.4f "
                  % (seen, epoch + 1, Epoch, loss2x.data[0]))
        else:
            # Sum the three scale losses and take a single step. Stepping
            # after each loss while backpropagating through a retained graph
            # applies gradients computed for the pre-step parameters, and
            # recent PyTorch versions reject it because step() modifies
            # tensors saved for backward in place.
            optimizerft.zero_grad()
            loss8x = loss_fn(input=outputs8x, target=labels8x)
            loss4x = loss_fn(input=outputs4x, target=labels4x)
            loss2x = loss_fn(input=outputs2x, target=labels2x)
            loss = loss8x + loss4x + loss2x
            loss.backward()
            optimizerft.step()
            print("Epoch [%d/%d/%d] Loss8x: %.4f Loss4x: %.4f Loss2x: %.4f "
                  % (seen, epoch + 1, Epoch,
                     loss8x.data[0], loss4x.data[0], loss2x.data[0]))

        # End the epoch once more than 8700 training samples have been used.
        if seen > 8700:
            break

    model.eval()
    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
        # Skip a short final batch, as in the training cell for LRR32s.
        if labels_val.size()[0] != batch_size_voc:
            continue
        images_val = Variable(images_val.cuda(), volatile=True)
        labels_val = Variable(labels_val.cuda(), volatile=True)

        outputs8x, outputs4x, outputs2x = model(images_val)
        # Score outputs2x, which is trained against the un-downscaled labels.
        pred = outputs2x.data.max(1)[1].cpu().numpy()
        gt = labels_val.data.cpu().numpy()
        running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()
    for k, v in score.items():
        print(k, v)
        results.append(v)
    # Reset once per epoch after all scores are read; this also runs when
    # `score` is empty, so metrics never leak into the next epoch.
    running_metrics.reset()