In [1]:
HEADS = 8
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2
import torchvision.transforms.functional

import numpy as np
import cv2
import os
import cv2 as cv
import matplotlib.pyplot as plt
import random

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, mode="train"):
        self.mode = mode
        if mode == "train":
            self.datapath = "/home/wg25r/fastdata/gasvid/train"
            self.images = [i for i in os.listdir(f"{self.datapath}/masks") if int(i.split("_")[-1].split(".")[0]) > 500]
        else: 
            self.datapath = "/home/wg25r/fastdata/gasvid/val"
            self.images = [i for i in os.listdir(f"{self.datapath}/masks") if int(i.split("_")[-1].split(".")[0]) > 500]

        self.ignore_before = 20
        self.ignore_after = 40 
        self.space_trans = torchvision.transforms.v2.Compose([
            torchvision.transforms.v2.RandomResizedCrop(448, scale=(0.6, 3)), 
            torchvision.transforms.v2.RandomHorizontalFlip(0.5),
            torchvision.transforms.v2.RandomRotation(20), 
            torchvision.transforms.v2.RandomApply(
                [torchvision.transforms.v2.ElasticTransform(alpha=50)], p=0.3
            ),
            torchvision.transforms.v2.RandomApply(
                [torchvision.transforms.v2.RandomPerspective()], p=0.3
            ),
            torchvision.transforms.v2.RandomApply(
                [torchvision.transforms.v2.RandomAffine(20,  scale=(0.5, 1.1))], p=0.3
            ),
        ])
        self.color_trans = torchvision.transforms.v2.Compose([
            torchvision.transforms.v2.RandomApply([torchvision.transforms.v2.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.4)
        ])

    def __len__(self): 
        if self.mode == "train":
            return int(len(self.images)) 
        else: 
            return int(len(self.images))
        

    def __getitem__(self, idx): 
        filename = self.images[idx]

        current_frame = cv2.resize(cv2.imread(f"{self.datapath}/in/{filename}"), (512, 512))
        long_bg = cv2.resize(cv2.imread(f"{self.datapath}/long/{filename}"), (512, 512))
        short_bg = cv2.resize(cv2.imread(f"{self.datapath}/short/{filename}"), (512, 512)) 
        label_ = cv2.imread(f"{self.datapath}/gt/{filename}")
        label = (label_ == 255) * 255.0
        ROI =  (label_ != 85) * 255.0
        label = cv2.resize(label, (512, 512))
        ROI = cv2.resize(ROI, (512, 512))

        current_frame = torch.tensor(current_frame).permute(2,0,1)
        long_bg = torch.tensor(long_bg).permute(2,0,1)
        short_bg = torch.tensor(short_bg).permute(2,0,1)
        label = torch.tensor(label).permute(2,0,1)
        X = torch.cat([current_frame, long_bg, short_bg], axis=0)
        # print(X.shape)
        
        Y = label.max(0)[None,:,:] 

        if self.mode == "train":  
            # X = self.color_trans(X) 
            YX = torch.cat((Y, X), axis=0) 
            YX = self.space_trans(YX)
            Y = YX[:1]/255.0  
            X = YX[1:]/255.0 
            X += torch.randn(X.shape) * 0.005
            X += torch.tensor(cv.resize(np.random.normal(0, 0.005, (10, 10)), X.shape[1:])).float()
            X *= 1 + torch.randn(9)[:,None,None] * 0.005
        else:
            X = X/255.0
            Y = Y/255.0 
        Y = torchvision.transforms.functional.resize(Y, (448//4, 448//4))[0]
        X[X<0]=0
        return X, Y 

In [3]:
vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')

Using cache found in /home/wg25r/.cache/torch/hub/facebookresearch_dino_main


In [4]:
class BCA(nn.Module):
    """
    Background-CurrentFrame Attention
    """
    def __init__(self, dim=384):
        super(BCA, self).__init__()
        self.cross_attention = nn.MultiheadAttention(dim, HEADS, dropout=0.1, batch_first=True, kdim=dim, vdim=dim * 2)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
    
    def forward(self, background, current_frame):
        """
        background: torch.Tensor, shape (batch, L, dim)
        current_frame: torch.Tensor, shape (batch, L, dim)
        """

        attn_output, _ = self.cross_attention(query=current_frame, key=background, value=torch.concatenate([background, current_frame], dim=-1))
        attn_output = self.norm1(attn_output + current_frame)
        mlp_output = self.mlp(attn_output)
        mlp_output = self.norm2(mlp_output + attn_output)
        return mlp_output

In [None]:
class BCA_list(nn.Module):
    """
    Background-CurrentFrame Attention
    """
    def __init__(self, dim=384):
        super(BCA, self).__init__()
        self.cross_attention = nn.MultiheadAttention(dim, HEADS, dropout=0.1, batch_first=True, kdim=dim, vdim=dim * 2)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
    
    def forward(self, long_background, short_background, toupimahuxikunduzikk current_frame):
        """
        background: torch.Tensor, shape (batch, L, dim)
        current_frame: torch.Tensor, shape (batch, L, dim)
        """

        attn_output, _ = self.cross_attention(query=current_frame, key=background, value=torch.concatenate([background, current_frame], dim=-1))
        attn_output = self.norm1(attn_output + current_frame)
        mlp_output = self.mlp(attn_output)
        mlp_output = self.norm2(mlp_output + attn_output)
        return mlp_output

In [6]:
class MyModel(nn.Module):
    def __init__(self, backbone):
        super(MyModel, self).__init__()
        self.backbone = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.bca_seq = nn.Sequential(*[BCA() for _ in range(4)])
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bicubic'),
            nn.Conv2d(384, 1, 1)
        ) 
        

    def forward(self, bg, current_frame):
        bg = self.backbone.get_intermediate_layers(bg)[0][:,1:,:]
        current_frame = self.backbone.get_intermediate_layers(current_frame)[0][:,1:,:]
        for bca in self.bca_seq:
            current_frame = bca(bg, current_frame)
        return self.decoder(current_frame.reshape(bg.shape[0], 448//8, 448//8, 384).permute(0,3,1,2))
    


# MyModel(vits8)(torch.randn(1, 3, 448, 448), torch.randn(1, 3, 448, 448)).shape


In [7]:
train_dataloader = torch.utils.data.DataLoader(MyDataset("train"), batch_size=8, shuffle=True, num_workers=80, persistent_workers=True, prefetch_factor=3)
val_dataloader = torch.utils.data.DataLoader(MyDataset("val"), batch_size=8, shuffle=True, num_workers=80, persistent_workers=True, prefetch_factor=3)
print("Train", len(train_dataloader), "Val", len(val_dataloader))

Train 14884 Val 77


In [8]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [9]:
def iou_loss(pred, target):
    pred = torch.sigmoid(pred)
    assert pred.shape == target.shape
    e = 1e-6
    iou = ((pred * target).sum() + e) / (pred.sum() + target.sum() - (pred * target).sum() + e)
    return 1 - iou

In [10]:
import wandb
mymodel = MyModel(vits8).cuda()
optimizer = torch.optim.Adam(mymodel.parameters(), lr=3e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6)
wandb.init(config={
    "lr": optimizer.param_groups[0]["lr"],
    "batch_size": train_dataloader.batch_size,
}, resume=False)




[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mwguo6358[0m ([33m3dsmile[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
loss_fn = iou_loss
epoch = 0
while 1:
    for i, (X, Y) in enumerate(train_dataloader):
        if i%500 == 0:
            scheduler.step() 

        if i%200==0:
            with torch.no_grad():
                mymodel.eval()
                total_loss = 0
                total_iou = 0
                for X_val, Y_val in val_dataloader:
                    X_val = X_val.cuda().float()
                    Y_val = Y_val.cuda().float()

                    pred = mymodel(X_val[:,0][:,None,:,:].repeat(1,3,1,1)
                    , X_val[:,1][:,None,:,:].repeat(1,3,1,1)
                    )
                    loss = loss_fn(pred, Y_val.cuda().unsqueeze(1)) 
                    total_loss += loss.item()
                    iou = (((pred > 0) & (Y_val.cuda().unsqueeze(1) > 0)).float().mean() + 1e-6)/(((pred > 0) | ((Y_val.cuda().unsqueeze(1) > 0))).float().mean() + 1e-6)
                    pred = pred[0].reshape((448//4, 448//4))
                    total_iou += iou.float()
                total_iou /= len(val_dataloader) 
                if total_iou > 0.5:
                    torch.save(mymodel.state_dict(), "final.pth") 

                total_loss /= len(val_dataloader) 
                wandb.log({"val_iou": total_iou, "has_gas_ratio":(Y_val.sum((1,2)) > 0).float().sum()/len(Y_val),
                    "real": wandb.Image(Y_val[0].cpu().detach().numpy().reshape(448//4, 448//4)), 
                    "pred": wandb.Image(pred.cpu().detach().numpy()>0),
                      "X_val": wandb.Image(X_val[0][0].cpu().detach().numpy()),
                      "X_BGS": wandb.Image(X_val[0][1].cpu().detach().numpy()),
                      "val_loss": total_loss}) 
                print("Val loss", total_loss, "Val iou", total_iou)
                mymodel.train()
                epoch += 1

            
        optimizer.zero_grad()
        X = X.cuda().float() 
        Y = Y.cuda().float()
        pred = mymodel(X[:,0][:,None,:,:].repeat(1,3,1,1), X[:,1][:,None,:,:].repeat(1,3,1,1))
        # loss = torchvision.ops.sigmoid_focal_loss(pred, Y.cuda().unsqueeze(1), alpha=1/(labels == 1).sum(), gamma=10, reduction="mean")
        loss = loss_fn(pred, Y.cuda().unsqueeze(1))
        acc = (pred > 0) == Y.cuda().unsqueeze(1)
        # f1 = f1_score(Y.unsqueeze(1).cpu().detach().numpy().reshape(-1).astype(int), pred.cpu().detach().numpy().reshape(-1) > 0)
        iou = (((pred > 0) & (Y.cuda().unsqueeze(1) > 0)).float().mean()  + 1e-6 )/(((pred > 0) | (Y.cuda().unsqueeze(1) > 0)) + 1e-6).float().mean()
        loss.backward() 
        optimizer.step()
        wandb.log({"loss": loss.item(), "acc": acc.float().mean().item(), "iou": iou.float(), "lr": optimizer.param_groups[0]["lr"]}) # cannot do iou mean here otherwise it average non overlapping area
        # same pred as image to wandb
        if i % 2000 == 0:
            torch.save(mymodel.state_dict(), "model_ft.pth")



Val loss 0.9872101785300614 Val iou tensor(0.0195, device='cuda:0')


KeyboardInterrupt: 

In [17]:
pred.shape, Y_val.unsqueeze(1).shape

(torch.Size([8, 1, 224, 224]), torch.Size([8, 1, 112, 112]))