To Do's 
1. Create a PyTorch model with BatchNorm2D layer
2. Train the model with cityscapes. 
3. Post-train to get priors. 
4. Use TTA during test time with foggy cityscapes. 

In [2]:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleConvModel(nn.Module):
    def __init__(self):
        super(SimpleConvModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(kernel_size=8, stride=8)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size=8, stride=8)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(kernel_size=8, stride=8)
        self.fc = nn.Linear(128, 10)  # Assuming input image size is 3x512x1024

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Instantiate the model
model = SimpleConvModel()
print(model)


SimpleConvModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=8, stride=8, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=8, stride=8, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool2d(kernel_size=8, stride=8, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=128, out_features=10, bias=True)
)


In [3]:
# Load cityscapes data
from torch.utils.data import DataLoader

from dataloader import Dataloader as CustomDataloader


augmentations = False
data_dir = "/users/Sadman/Test_Time_Domain_Adaptation/data/"
preprocessing_config = {"mean":[0,0,0], "std_dev":[1,1,1]}
train_fetcher = CustomDataloader('train', augmentations, data_dir, preprocessing_config)
train_dataloader = DataLoader(train_fetcher, batch_size=6, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train 
for imgs, labels in train_dataloader:
    imgs = imgs.to(device)
    labels = labels.to(device)
    print("imgs shape", imgs.shape)
    print("labels shape", labels.shape)
    break

Total Data 2975
imgs shape torch.Size([6, 3, 512, 1024])
labels shape torch.Size([6, 512, 1024])


In [8]:
from tqdm import tqdm

model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
for step, (imgs, labels) in tqdm(enumerate(train_dataloader)):
    imgs = imgs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    out = model(imgs)
    targets = labels.sum(dim=1).sum(dim=1).long() % 10
    loss = loss_func(out, targets)
    loss.backward()
    optimizer.step()
    
    if step == 10: 
        break

10it [00:09,  1.05it/s]


In [9]:
import albumentations as A

def aug(imgs):
    augmentor = A.Compose([
        A.HorizontalFlip(p=0.8),
        A.RandomBrightnessContrast(p=0.9),
        A.ShiftScaleRotate(p=0.7),
    ])
    
    aug_imgs = []
    for img in imgs:
        img = img.cpu().permute(1, 2, 0).numpy()
        augmented = augmentor(image=img)["image"]
        aug_imgs.append(augmented)
    
    return torch.tensor(aug_imgs).permute(0, 3, 1, 2)
    

# Calculate Prior
for step, (imgs, labels) in tqdm(enumerate(train_dataloader)):
    imgs = imgs.to(device) # [6, 3, 512, 1024]
    labels = labels.to(device) # [512, 1024]
    aug_imgs = aug(imgs)
    aug_imgs = aug_imgs.to(device)
    
    # Orig image grads
    optimizer.zero_grad()
    out = model(imgs)
    targets = labels.sum(dim=1).sum(dim=1).long() % 10
    loss = loss_func(out, targets)
    loss.backward()
    
    print("model.bn1.weight", model.bn1.weight)
    bn1_grad = model.bn1.weight.grad
    
    # UNet model
    # encoder_layer1_bn1_0_grad = model.encoder.layer1[0].bn1.weight.grad
    print("grad shape", bn1_grad.shape) 
    
    # Augmented image grads
    optimizer.zero_grad()
    out = model(aug_imgs)
    targets = labels.sum(dim=1).sum(dim=1).long() % 10
    loss = loss_func(out, targets)
    loss.backward()
    
    bn1_grad_aug = model.bn1.weight.grad
    print("grad shape", bn1_grad_aug.shape)
    
    if step == 1:
        break
    
    
    
    

0it [00:00, ?it/s]

0it [00:15, ?it/s]

model.bn1.weight Parameter containing:
tensor([0.9927, 1.0026, 1.0046, 0.9921, 0.9904, 1.0062, 0.9985, 0.9992, 0.9940,
        1.0030, 1.0058, 0.9946, 0.9996, 1.0009, 1.0010, 1.0061],
       device='cuda:0', requires_grad=True)
grad shape torch.Size([16])





RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

# Train 
