In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from SSLRS.segdata import *
import pandas as pd
import torch

In [None]:
from fastai.vision.all import *
from mmcv.utils import Config, DictAction, get_git_hash
from mmseg.models import build_segmentor
from semantic_segmentation.backbone import xcit
import timm
from SSLRS.xcit import XCiT
from mmseg.models.decode_heads import FPNHead
from mmseg.models.necks import FPN
from SSLRS.segdata import *

In [None]:
class FCNNET(nn.Module):
    """Segmentation network: XCiT backbone -> FPN neck -> FPN decode head.

    The backbone is built with ``in_chans=10`` and the decode head with
    ``num_classes=11``. By default the backbone is initialised from the
    ``'teacher'`` entry of a self-supervised checkpoint (see ``init_weights``).

    Args:
        pretrained: if True (default), load the checkpoint at construction.
            Pass False to get a randomly initialised model, e.g. for a
            baseline without SSL pretraining.
        checkpoint_path: checkpoint file read by ``init_weights``.
    """

    def __init__(self, pretrained=True, checkpoint_path='./SSLmodels/oldtrain/checkpoint.pth'):
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.backbone = XCiT(num_classes=0, in_chans=10, patch_size=8, embed_dim=384, depth=12,
                             num_heads=8, eta=1.0, tokens_norm=True)
        # The neck is configured for four 384-channel inputs; presumably the
        # backbone returns four feature maps -- confirm against SSLRS.xcit.
        self.neck = FPN(in_channels=[384, 384, 384, 384], out_channels=384, num_outs=4)
        self.decode_head = FPNHead(feature_strides=[8, 8, 8, 8], in_channels=[384, 384, 384, 384],
                                   channels=128, num_classes=11, in_index=[0, 1, 2, 3],
                                   dropout_ratio=0.1, align_corners=False)
        if pretrained:
            self.init_weights()

    def forward(self, x):
        """Return the decode-head class scores for the input batch ``x``."""
        features = self.backbone(x)
        return self.decode_head(self.neck(features))

    def init_weights(self, checkpoint_path=None):
        """Load the checkpoint's ``'teacher'`` weights into the backbone.

        Leading ``module.`` and ``backbone.`` prefixes are stripped from the
        keys. Keys that do not belong to the backbone are ignored
        (``strict=False``). A summary of loaded / missing / unexpected keys is
        printed, so a load that silently matched nothing is visible.

        Args:
            checkpoint_path: file to load; defaults to ``self.checkpoint_path``.
        """
        path = checkpoint_path if checkpoint_path is not None else getattr(
            self, 'checkpoint_path', './SSLmodels/oldtrain/checkpoint.pth')
        # map_location='cpu' lets a GPU-saved checkpoint load on any host;
        # load_state_dict copies the tensors onto the parameters' device anyway.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        teacher = torch.load(path, map_location='cpu')['teacher']
        state_dict = {self._strip_prefixes(k): v for k, v in teacher.items()}
        result = self.backbone.load_state_dict(state_dict, strict=False)
        n_loaded = len(state_dict) - len(result.unexpected_keys)
        print(f'load: {n_loaded} tensors loaded, {len(result.missing_keys)} missing, '
              f'{len(result.unexpected_keys)} unexpected')

    @staticmethod
    def _strip_prefixes(key, prefixes=('module.', 'backbone.')):
        """Remove each prefix, in order, from the start of ``key`` if present."""
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        return key
def mIOU(pred, label, num_classes=12):
    """Mean IoU over the non-zero classes present in ``label`` (fastai metric).

    The function name is the metric name fastai logs and that
    ``SaveModelCallback(monitor='mIOU')`` tracks, so it must not change.

    Args:
        pred: class-score tensor (B, C, h, w). It is nearest-upsampled to
            ``label``'s spatial size (4x in this notebook) before the argmax.
        label: integer class map; presumably (B, H, W), matching the target
            layout the OHEM loss uses.
        num_classes: classes ``1 .. num_classes-1`` are scored. Class 0 is
            always skipped.

    Returns:
        Mean IoU over the scored classes that appear in ``label``. Classes
        absent from ``label`` are left out of the mean. Returns NaN if none of
        them appears.
    """
    label = MSTensorImage(label)
    # Upsampling to the label size is identical to scale_factor=4 when the
    # label is exactly 4x larger, and still valid when it is not.
    pred = F.interpolate(pred, size=tuple(label.shape[-2:]), mode='nearest')
    # argmax of raw scores equals argmax of their softmax, so softmax is skipped.
    pred = pred.argmax(dim=1).view(-1)
    label = label.view(-1)
    ious = []
    for sem_class in range(1, num_classes):
        target_inds = label == sem_class
        n_target = target_inds.sum().item()
        if n_target == 0:
            continue  # class absent from this batch: excluded from the mean
        pred_inds = pred == sem_class
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + n_target - intersection
        ious.append(intersection / union)
    # np.mean([]) would warn and return NaN; make the empty case explicit.
    return np.mean(ious) if ious else float('nan')

class OhemCrossEntropy(nn.Module):
    """Online hard example mining (OHEM) cross-entropy loss.

    The per-pixel cross-entropy is averaged only over "hard" pixels. A pixel is
    hard when its predicted probability for the true class is below
    ``max(thres, p_k)``, where ``p_k`` is the ``min_kept``-th smallest such
    probability among non-ignored pixels. This keeps roughly ``min_kept``
    pixels (up to ties) even when most pixels are easier than ``thres``.

    Args:
        ignore_label: target value excluded from the loss.
        thres: probability below which a pixel always counts as hard.
        min_kept: target minimum number of hard pixels (clamped to >= 1).
        weight: optional per-class weight forwarded to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, ignore_label=-1, thres=0.7,
                 min_kept=100000, weight=None):
        super(OhemCrossEntropy, self).__init__()
        self.thresh = thres
        self.min_kept = max(1, min_kept)
        self.ignore_label = ignore_label
        self.weight = weight
        # reduction='none' keeps per-pixel losses so hard pixels can be selected.
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label,
                                             reduction='none')

    def forward(self, score, target, **kwargs):
        """Return the OHEM loss for logits ``score`` (B, C, h, w) and ``target``.

        ``score`` is nearest-upsampled to ``target``'s spatial size (4x in this
        notebook) before the loss is computed. If no pixel is valid, or none is
        hard, the result is a zero loss that stays attached to the graph. In
        the first case the code would otherwise raise an IndexError; in the
        second it would return NaN. ``**kwargs`` is accepted and ignored.
        """
        target = MSTensorImage(target.long())
        score = F.interpolate(score, size=tuple(target.shape[-2:]), mode='nearest')
        pixel_losses = self.criterion(score, target).contiguous().view(-1)
        mask = target.contiguous().view(-1) != self.ignore_label
        if not mask.any():
            # All pixels ignored: criterion gave 0 for each, so this sum is 0.
            return pixel_losses.sum()

        # Probability of the true class per pixel. Ignored pixels get dummy
        # class 0 so gather is valid; they are masked out right after.
        safe_target = target.clone()
        safe_target[safe_target == self.ignore_label] = 0
        true_prob = F.softmax(score, dim=1).gather(1, safe_target.unsqueeze(1))
        true_prob, order = true_prob.contiguous().view(-1)[mask].contiguous().sort()
        min_value = true_prob[min(self.min_kept, true_prob.numel() - 1)]
        threshold = max(min_value, self.thresh)

        hard_losses = pixel_losses[mask][order][true_prob < threshold]
        if hard_losses.numel() == 0:
            # Every pixel is "easy" (e.g. probabilities saturated at 1.0);
            # the mean of an empty tensor would be NaN.
            return hard_losses.sum()
        return hard_losses.mean()
      

In [None]:
# Build the segmentation model. FCNNET.__init__ calls init_weights(), which
# loads the SSL checkpoint's 'teacher' weights into the backbone and prints a
# 'load' message (the output below).
model=FCNNET()

load


In [None]:
# Training index. The DataBlock below reads its 'names' (inputs) and 'masks'
# (targets) columns -- presumably file paths; confirm against train.csv.
df=pd.read_csv('train.csv')

In [None]:
# Inputs are created with MSTensorImage.create from the 'names' column and
# targets with MSMask.create from the 'masks' column.
# RandomSplitter(seed=10) gives a reproducible train/validation split.
# NOTE(review): `aug` and `aug2` are not defined in this notebook; presumably
# they come from `from SSLRS.segdata import *` -- confirm.
db = DataBlock(blocks=(TransformBlock(type_tfms=partial(MSTensorImage.create)),
                       TransformBlock(type_tfms=partial(MSMask.create)),
                      ),
               get_x=ColReader('names'),
                get_y=ColReader('masks'),
               splitter=RandomSplitter(seed=10),
               item_tfms=[aug,aug2],            
              )

In [None]:
# Batch size 6, 8 loader worker processes, pinned host memory for faster GPU copies.
dls = db.dataloaders(source=df,bs=6,num_workers=8,pin_memory=True)

In [None]:
# OHEM cross-entropy with its defaults (ignore_label=-1, thres=0.7, min_kept=100000).
loss=OhemCrossEntropy()
# Mixed-precision Learner; mIOU is reported each epoch and is the metric
# monitored by SaveModelCallback below.
learn = Learner(dls,model,metrics=mIOU,loss_func=loss).to_fp16()


In [None]:
# 200-epoch one-cycle schedule with max lr 1e-4. CSVLogger appends per-epoch
# stats to SegSSLoldweight.csv. SaveModelCallback keeps the weights with the
# best (highest) validation mIOU under the name 'SegSSLoldweight'.
learn.fit_one_cycle(200, 1e-4,cbs=[CSVLogger(fname='SegSSLoldweight.csv',append=True),SaveModelCallback(monitor='mIOU',fname='SegSSLoldweight')])

epoch,train_loss,valid_loss,mIOU,time


In [None]:
# NOTE(review): this reuses `learn`, whose model was initialised from the SSL
# checkpoint (FCNNET.__init__ -> init_weights) and, on a top-to-bottom run,
# was already trained by the previous cell. Despite the 'SegnoSSL' name, this
# is NOT a baseline without SSL pretraining. For a real baseline, build a
# fresh FCNNET whose backbone is not loaded from the checkpoint, and wrap it
# in a new Learner before fitting.
learn.fit_one_cycle(200, 1e-4,cbs=[CSVLogger(fname='SegnoSSL.csv',append=True),SaveModelCallback(monitor='mIOU',fname='SegnoSSL')])