In [1]:
from imports import *
from datasets.idd import *
from datasets.bdd import *
from detection.unet import *
from collections import OrderedDict
from torch_cluster import nearest
from fastprogress import master_bar, progress_bar

Unet loaded successfully


In [2]:
# Mini-batch size handed to both DataLoaders created further down.
batch_size = 8
# Length of the epoch range that the training loop iterates over.
num_epochs = 1

In [3]:
# --- BDD100K: paths, pickled image lists, dataset and loader -----------------
path = '/home/jupyter/autonue/data'
root_img_path = os.path.join(path, 'bdd100k', 'images', '100k')
root_anno_path = os.path.join(path, 'bdd100k', 'labels')

# Image folders for the two splits (built by plain concatenation, so the
# trailing slash is part of the value).
train_img_path, val_img_path = root_img_path+'/train/', root_img_path+'/val/'

# Annotation JSON files for the two splits.
train_anno_json_path, val_anno_json_path = (
    root_anno_path+'/bdd100k_labels_images_train.json',
    root_anno_path+'/bdd100k_labels_images_val.json',
)

print("Loading files")

# The image-path lists are pickles despite the .txt suffix.
# NOTE: pickle.load can execute arbitrary code; only load trusted datalists.
with open("datalists/bdd100k_train_images_path.txt", "rb") as fp:
    train_img_path_list = pickle.load(fp)

with open("datalists/bdd100k_val_images_path.txt", "rb") as fp:
    val_img_path_list = pickle.load(fp)

# Training-split dataset; `dset` is kept as a second name for the same object.
src_dataset = BDD(train_img_path_list, train_anno_json_path, get_transform(train=True))
dset = src_dataset

src_dl = torch.utils.data.DataLoader(
    src_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=utils.collate_fn,
)

Loading files


100%|██████████| 69863/69863 [00:02<00:00, 25953.05it/s]


In [4]:
# --- IDD: pickled path lists, dataset and loader ------------------------------
# NOTE: pickle.load can execute arbitrary code; only load trusted datalists.
with open("datalists/idd_images_path_list.txt", "rb") as fp:
    non_hq_img_paths = pickle.load(fp)

with open("datalists/idd_anno_path_list.txt", "rb") as fp:
    non_hq_anno_paths = pickle.load(fp)

with open("datalists/idd_hq_images_path_list.txt", "rb") as fp:
    hq_img_paths = pickle.load(fp)

with open("datalists/idd_hq_anno_path_list.txt", "rb") as fp:
    hq_anno_paths = pickle.load(fp)

# Only the HQ lists feed the dataset; swap in (or append) the non_hq_* lists
# here to change which images are used.
trgt_images, trgt_annos = hq_img_paths, hq_anno_paths

trgt_dataset = IDD(trgt_images, trgt_annos, get_transform(train=True))
trgt_dl = torch.utils.data.DataLoader(
    trgt_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=utils.collate_fn,
)

In [5]:
#src_dataset[0][0].shape,trgt_dataset[0][0].shape

In [6]:
class TransportBlock(nn.Module):
    """Frozen feature backbone paired with a trainable U-Net and a transport loss.

    The backbone is moved to the GPU and all of its parameters are frozen; the
    U-Net (built with ``n_channels``) is the only trainable part. The loss in
    :meth:`transport_loss` compares source feature maps with U-Net-processed
    target feature maps.

    Args:
        backbone: feature extractor module; it is moved to CUDA and frozen.
        n_channels: channel count passed to ``Unet``.
        batch_size: maximum number of batch items ``transport_loss`` processes
            per call.
    """

    def __init__(self,backbone,n_channels=256,batch_size=2):
        super(TransportBlock, self).__init__()
        self.backbone = backbone.cuda()
        # (mean, std) triples; not read anywhere inside this class.
        self.stats = [0.485, 0.456, 0.406],[0.229, 0.224, 0.225]
        # Honour the constructor argument (it used to be overwritten with 2).
        self.batch_size = batch_size
        self.unet = Unet(n_channels).cuda()

        # Freeze the backbone so only the U-Net receives gradient updates.
        for p in self.backbone.parameters():
            p.requires_grad = False

    def unet_forward(self,x):
        """Run ``x`` through the U-Net and return its output."""
        return self.unet(x)

    def transport_loss(self,S_embeddings, T_embeddings, N_cluster=5):
        """Cluster-based transport loss between source and target feature maps.

        For each batch item (at most ``self.batch_size`` of them) the feature
        map is flattened to one row per channel, and then:

        1. ``N_cluster`` distinct source rows are drawn as seeds and every
           source row is assigned to its nearest seed; the per-cluster means
           are the source centroids.
        2. Every target row is assigned to its nearest source centroid; the
           per-cluster means are the target centroids. A cluster that attracts
           no target rows reuses the source centroid.
        3. The item's loss is the mean over clusters of
           ``centroid_distance * intra_class_variance``, both measured as mean
           squared differences.

        NOTE(review): steps 1-3 are a reconstruction of the intended algorithm.
        The earlier code overwrote its inputs inside the batch loop, iterated
        over per-row labels instead of cluster ids, and flattened the centroids
        with ``torch.cat``. Confirm the reconstruction against the method
        description.

        Args:
            S_embeddings: source features, indexed as (batch, channel, ...).
            T_embeddings: target features with the same per-item layout.
            N_cluster: number of clusters; capped at the number of rows.

        Returns:
            0-dim tensor holding the loss summed over the processed items.

        Raises:
            ValueError: if a source/target pair flattens to different row
                lengths, which makes nearest-centroid assignment undefined.
        """
        def mean_sq_dist(x, y):
            return torch.mean((x - y) ** 2)

        # Never index past the real batch (e.g. a short final batch).
        n_items = min(self.batch_size, S_embeddings.shape[0], T_embeddings.shape[0])
        # Tensor accumulator so callers can always use .item() / .backward().
        Loss = T_embeddings.new_zeros(())

        for item in range(n_items):
            # Fresh locals per item: the inputs themselves must stay untouched.
            src = S_embeddings[item].reshape(S_embeddings.shape[1], -1)
            tgt = T_embeddings[item].reshape(T_embeddings.shape[1], -1)
            if src.shape[1] != tgt.shape[1]:
                raise ValueError(
                    "source and target feature maps must flatten to the same length, "
                    "got {} and {}".format(src.shape[1], tgt.shape[1]))

            # Distinct seeds: duplicates would create empty clusters (NaN means).
            n_clusters = min(N_cluster, src.shape[0])
            seed_rows = torch.as_tensor(
                np.random.choice(src.shape[0], n_clusters, replace=False),
                dtype=torch.long, device=src.device)
            seeds = src[seed_rows]

            # Assignments are discrete, so they are computed on detached tensors.
            src_labels = nearest(src.detach(), seeds.detach())
            src_centroids = torch.stack([
                src[src_labels == c].mean(dim=0) if bool((src_labels == c).any()) else seeds[c]
                for c in range(n_clusters)
            ])

            tgt_labels = nearest(tgt.detach(), src_centroids.detach())

            centroid_distance = []
            intra_class_variance = []
            for c in range(n_clusters):
                members = tgt[tgt_labels == c]
                if members.shape[0] > 0:
                    tgt_centroid = members.mean(dim=0)
                    intra_class_variance.append(mean_sq_dist(members, tgt_centroid))
                else:
                    # Empty target cluster: fall back to the source centroid and
                    # contribute zero spread instead of NaN.
                    tgt_centroid = src_centroids[c]
                    intra_class_variance.append(tgt.new_zeros(()))
                centroid_distance.append(mean_sq_dist(tgt_centroid, src_centroids[c]))

            # similar to earth mover distance
            Loss = Loss + torch.mean(
                torch.stack(centroid_distance) * torch.stack(intra_class_variance))
        return Loss

In [7]:
def get_model(num_classes):
    """Return a CPU Faster R-CNN (ResNet-50 FPN) with a box head for `num_classes`.

    The detector is created with ``pretrained=True`` and its box predictor is
    then replaced by a new ``FastRCNNPredictor`` of the requested size.
    """
    detection = torchvision.models.detection
    detector = detection.fasterrcnn_resnet50_fpn(pretrained=True).cpu()

    # The new head must accept the same feature width as the one it replaces.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    new_head = detection.faster_rcnn.FastRCNNPredictor(head_in_features, num_classes).cpu()
    detector.roi_heads.box_predictor = new_head

    return detector.cpu()

In [8]:
# Deserialise the BDD100K detector checkpoint; its 'model' entry is what the
# detector's load_state_dict consumes.
ckpt = torch.load("saved_models/bdd100k_24.pth")

In [9]:
# Build the 12-class detector and restore its weights from the checkpoint that
# is already held in `ckpt`, rather than reading and unpickling the same file
# from disk a second time.
model = get_model(12)
model.load_state_dict(ckpt['model'])

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [10]:
# Wrap the detector backbone; only the U-Net inside the block is optimised.
# NOTE(review): TransportBlock is built with its default batch_size, not the
# notebook-level `batch_size` used by the DataLoaders - confirm this is intended.
ot = TransportBlock(model.backbone)

params = list(filter(lambda weight: weight.requires_grad, ot.unet.parameters()))

optimizer = torch.optim.SGD(
    params,
    lr=1e-3,
    momentum=0.9,
    weight_decay=0.0005,
)
lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=1e-3,
    max_lr=6e-3,
)

In [11]:
from detection import transform

# Rebinds `transform` from the imported module to a configured
# GeneralizedRCNNTransform instance; later cells call it on image lists.
transform = transform.GeneralizedRCNNTransform(
    min_size=800,
    max_size=1333,
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
)
# Put the instance in eval mode; eval() returns it, so the cell displays its repr.
transform.eval()

GeneralizedRCNNTransform()

In [12]:
def _pad_top_left(feat, height, width):
    """Zero-pad a (B, C, H, W) feature map on its top/left edges up to (height, width)."""
    pad_h = height - feat.size(2)
    pad_w = width - feat.size(3)
    if pad_h == 0 and pad_w == 0:
        return feat
    # F.pad order for the last two dims is (left, right, top, bottom).
    return F.pad(feat, (pad_w, 0, pad_h, 0)).contiguous()


# Train the U-Net: every target batch is paired with one source batch.
#
# Changes relative to the earlier debugging version of this cell:
#   * the unconditional `break` after the first forward pass (and the feature
#     dumps to *.pth files next to it) are gone, so the loss/optimiser steps
#     actually run;
#   * one source iterator is kept alive and restarted only when exhausted,
#     instead of building a new DataLoader iterator on every step;
#   * feature maps are padded to their common maximum size rather than to the
#     hard-coded 336/192;
#   * the trailing `break` after the first epoch is gone, so `num_epochs` is honoured;
#   * loop variables no longer reuse (and delete) the notebook-level `trgt_images`.
mb = master_bar(range(num_epochs))
src_iter = iter(src_dl)
for epoch in mb:
    for trgt_img, _ in progress_bar(trgt_dl, parent=mb):
        try:
            src_img, _ = next(src_iter)
        except StopIteration:
            # Source loader exhausted before the target loader: start it again.
            src_iter = iter(src_dl)
            src_img, _ = next(src_iter)

        src_batch = [image.cuda() for image in src_img]
        trgt_batch = [image.cuda() for image in trgt_img]

        # A short final batch on either side would break the shape check below.
        n_pairs = min(len(src_batch), len(trgt_batch))
        src_batch, trgt_batch = src_batch[:n_pairs], trgt_batch[:n_pairs]

        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            src_list, _ = transform(src_batch, None)
            src_features = ot.backbone(src_list.tensors)[0]

            trgt_list, _ = transform(trgt_batch, None)
            trgt_features = ot.backbone(trgt_list.tensors)[0]

        modified_trgt_features = ot.unet_forward(trgt_features)

        # pad if dim of feature maps are not same
        height = max(src_features.size(2), modified_trgt_features.size(2))
        width = max(src_features.size(3), modified_trgt_features.size(3))
        src_features = _pad_top_left(src_features, height, width)
        modified_trgt_features = _pad_top_left(modified_trgt_features, height, width)
        assert src_features.shape==modified_trgt_features.shape

        loss = ot.transport_loss(src_features,modified_trgt_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Show progress on the inner bar instead of printing one line per step.
        mb.child.comment = "transport_loss: {:.6f} lr: {:.2e}".format(
            loss.item(), optimizer.param_groups[0]["lr"])

        # Release the per-step GPU tensors before the next batch is loaded.
        del src_batch, trgt_batch, src_list, trgt_list
        del src_features, trgt_features, modified_trgt_features, loss

In [13]:
# Persist the U-Net weights together with the optimiser state.
unet_checkpoint = {
    'model_state_dict': ot.unet.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(unet_checkpoint, 'saved_models/unet.pth')