In [1]:
class Args:
    def __init__(self):
        self.attention_probs_dropout_prob = 0.2
        self.hidden_dropout_prob = 0.2
        self.drop_path_rate = 0.1
        self.classifier_dropout = 0.1
        self.ksteps = 1000

args = Args()

In [2]:
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation, SegformerConfig
import torch
import numpy as np
from PIL import Image


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
config = SegformerConfig.from_pretrained("nvidia/segformer-b2-finetuned-ade-512-512")
config.hidden_dropout_prob = args.hidden_dropout_prob
config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
config.drop_path_rate = args.drop_path_rate
config.classifier_dropout = args.classifier_dropout
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b2-finetuned-ade-512-512", config=config)
model.segformer.encoder.patch_embeddings[0].proj = torch.nn.Conv2d(9, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
model.decode_head.classifier = torch.nn.Conv2d(768, 1, kernel_size=(1, 1), stride=(1, 1))

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b2-finetuned-ade-512-512 and are newly initialized: ['segformer.encoder.test.bias', 'segformer.encoder.test.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
class trainer:
    def __init__(self, model, optimizer, lr_scheduler, train_dataloader, val_dataloader, logger, loss_fn):
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.logger = logger
        self.loss_fn = loss_fn
        self.running_loss = 0
        self.running_f1 = 0
        self.step = 0

    
    def train_step(self, X, Y, ROI):
        self.model.train()
        self.optimizer.zero_grad()
        pred = self.model(X).logits * ROI
        Y = Y * ROI
        loss = self.loss_fn(pred, Y)
        loss.backward()
        self.optimizer.step()
        f1 = (2 * (pred > 0).float() * Y.float()).sum() / ((pred > 0).float() + Y.float()).sum()
        
        self.running_loss += loss.item()
        self.running_f1 += f1.item()

    def validate(self, X, Y, ROI):
        self.model.eval()
        with torch.no_grad():
            pref = self.model(X).logits * ROI
            Y = Y * ROI
            loss = self.loss_fn(pref, Y)
            f1 = (2 * (pref > 0).float() * Y.float()).sum() / ((pref > 0).float() + Y.float()).sum()
        return loss, f1
    
    
    def train_epoch(self):
        for i, (X, Y, ROI) in enumerate(self.train_dataloader):
            self.train_step(X, Y, ROI)
            if i%1000 == 0: 
                self.logger.log({"loss": self.running_loss / 100, "f1": self.running_f1 / 100})
                print(f"Epoch {self.epoch}, Step {i}, Loss: {self.running_loss / 100}, F1: {self.running_f1 / 100}")
                for i, (val_X, val_Y, val_ROI) in enumerate(self.val_dataloader):
                    val_runnning_loss, val_running_f1 = 0, 0
                    val_loss, val_f1 = self.validate(val_X, val_Y, val_ROI)
                    val_runnning_loss += val_loss
                    val_running_f1 += val_f1
                
                self.logger.log({"val_loss": val_runnning_loss / len(self.val_dataloader), "val_f1": val_running_f1 / len(self.val_dataloader)})
                print(f"Validation Loss: {val_runnning_loss / len(self.val_dataloader)}, Validation F1: {val_running_f1 / len(self.val_dataloader)}")

                self.running_loss = 0
                self.running_f1 = 0
                self.lr_scheduler.step()
                self.step += 1
                if self.step >= args.ksteps:
                    raise StopIteration

    def train(self):
        while True:
            try:
                self.train_epoch()
            except StopIteration:
                break

from dataloader import CustomDataset
import wandb
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.ksteps)
train_dataset = CustomDataset("/mnt/fastdata/preaug_cdnet/", "/mnt/fastdata/CDNet", 1, "train")
val_dataset = CustomDataset("/mnt/fastdata/preaug_cdnet/", "/mnt/fastdata/CDNet", 1, "val")
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=20)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=20)
loss_fn = torch.nn.BCEWithLogitsLoss()

wandb.init(project="Remeow")
logger = wandb
trainer = trainer(model, optimizer, lr_scheduler, train_dataloader, val_dataloader, logger, loss_fn)


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mwguo6358[0m ([33m3dsmile[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
trainer.train()

RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dimension 3