In [None]:
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange, repeat


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    A single bias-free linear layer produces queries, keys and values for
    all heads at once; a second bias-free linear layer mixes the
    concatenated head outputs.

    Args:
        embedding_dim: Size of each token embedding. Must be a multiple of
            ``head_num``.
        head_num: Number of attention heads.

    Raises:
        ValueError: If ``embedding_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, embedding_dim, head_num):
        super().__init__()

        if embedding_dim % head_num != 0:
            raise ValueError(
                "embedding_dim ({}) must be divisible by head_num ({})".format(
                    embedding_dim, head_num))

        self.head_num = head_num
        # Factor multiplied into the attention logits: 1 / sqrt(d_head).
        # The earlier expression `d ** 1 / 2` parsed as `(d ** 1) / 2`, so
        # the logits were scaled *up* by d_head / 2 instead of being
        # normalised, which saturates the softmax.
        self.dk = (embedding_dim // head_num) ** -0.5

        self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
        self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``embedding_dim``; the einops
                patterns below treat it as (batch, tokens, embedding_dim).
            mask: Optional boolean tensor passed to ``masked_fill``; logits
                at positions where it is True are set to ``-inf`` before
                the softmax.

        Returns:
            Tensor with the same (batch, tokens, embedding_dim) layout as
            the input.
        """
        qkv = self.qkv_layer(x)

        # Split the fused projection into q/k/v (k=3) and into heads.
        query, key, value = tuple(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.head_num))
        energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk

        if mask is not None:
            energy = energy.masked_fill(mask, -np.inf)

        attention = torch.softmax(energy, dim=-1)

        x = torch.einsum("... i j , ... j d -> ... i d", attention, value)

        # Concatenate the heads back into one embedding per token.
        x = rearrange(x, "b h t d -> b t (h d)")
        x = self.out_attention(x)

        return x


class MLP(nn.Module):
    """Feed-forward sub-network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embedding_dim: Input and output feature size.
        mlp_dim: Width of the hidden layer.
    """

    def __init__(self, embedding_dim, mlp_dim):
        super().__init__()

        # Both dropout layers use p = 0.1.
        stages = [
            nn.Linear(embedding_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(mlp_dim, embedding_dim),
            nn.Dropout(0.1),
        ]
        self.mlp_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        return self.mlp_layers(x)


class TransformerEncoderBlock(nn.Module):
    """One encoder layer: attention and MLP, each with a residual + LayerNorm.

    Normalisation is applied *after* each residual addition.

    Args:
        embedding_dim: Token embedding size.
        head_num: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward sub-network.
    """

    def __init__(self, embedding_dim, head_num, mlp_dim):
        super().__init__()

        self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
        self.mlp = MLP(embedding_dim, mlp_dim)

        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

        # Applied to the attention output only; the MLP has its own dropout.
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward sub-layers."""
        attended = self.dropout(self.multi_head_attention(x))
        x = self.layer_norm1(x + attended)

        x = self.layer_norm2(x + self.mlp(x))

        return x


class TransformerEncoder(nn.Module):
    """A stack of ``block_num`` identical ``TransformerEncoderBlock`` layers.

    Args:
        embedding_dim: Token embedding size.
        head_num: Number of attention heads per block.
        mlp_dim: Hidden width of each block's feed-forward network.
        block_num: How many blocks to stack (default 12).
    """

    def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12):
        super().__init__()

        blocks = []
        for _ in range(block_num):
            blocks.append(TransformerEncoderBlock(embedding_dim, head_num, mlp_dim))
        self.layer_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the output."""
        out = x
        for block in self.layer_blocks:
            out = block(out)

        return out


class ViT(nn.Module):
    """Vision Transformer: patchify, project, add class token, encode.

    Args:
        img_dim: Side length of the (square) input image.
        in_channels: Number of input channels.
        embedding_dim: Token embedding size.
        head_num: Attention heads per encoder block.
        mlp_dim: Hidden width of each block's feed-forward network.
        block_num: Number of encoder blocks.
        patch_dim: Side length of a square patch.
        classification: When True a linear head is created and ``forward``
            returns logits computed from the class token; otherwise
            ``forward`` returns the patch tokens.
        num_classes: Output size of the classification head.
    """

    def __init__(self, img_dim, in_channels, embedding_dim, head_num, mlp_dim,
                 block_num, patch_dim, classification=False, num_classes=1):
        super().__init__()

        self.patch_dim = patch_dim
        self.classification = classification
        self.num_tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)

        self.projection = nn.Linear(self.token_dim, embedding_dim)
        # Learned additive embedding, one row per token plus one for the
        # class token.
        self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))

        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        self.dropout = nn.Dropout(0.1)

        self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num)

        if self.classification:
            self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Encode an image batch.

        Args:
            x: Tensor laid out as (batch, channels, height, width); height
                and width must be multiples of ``patch_dim``.

        Returns:
            ``mlp_head`` applied to the class token when ``classification``
            is True, otherwise the encoded tokens with the class token
            removed.
        """
        side = self.patch_dim
        img_patches = rearrange(
            x,
            'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
            patch_x=side, patch_y=side)

        batch_size, tokens = img_patches.shape[0], img_patches.shape[1]

        projected = self.projection(img_patches)
        # One copy of the class token per sample, placed in front.
        cls_tokens = repeat(self.cls_token, 'b ... -> (b batch_size) ...',
                            batch_size=batch_size)

        sequence = torch.cat([cls_tokens, projected], dim=1)
        sequence = sequence + self.embedding[:tokens + 1, :]

        encoded = self.transformer(self.dropout(sequence))

        if self.classification:
            return self.mlp_head(encoded[:, 0, :])
        return encoded[:, 1:, :]


if __name__ == '__main__':
    # Smoke test: build a small ViT, report its parameter count and the
    # output shape for one random 13-channel 48x48 input.
    vit_config = dict(img_dim=48,
                      in_channels=13,
                      patch_dim=4,
                      embedding_dim=512,
                      block_num=6,
                      head_num=4,
                      mlp_dim=512)
    vit = ViT(**vit_config)

    parameter_total = sum(p.numel() for p in vit.parameters())
    print(parameter_total)

    sample = torch.rand(1, 13, 48, 48)
    print(vit(sample).shape)


9637376
torch.Size([1, 144, 512])


In [None]:
import torch
import torch.nn as nn
from einops import rearrange
class ACBlock(nn.Module):
    """Sum of a 3x3, a 1x3 and a 3x1 convolution, then BatchNorm and ReLU.

    All three branches use stride 1 and padding that preserves the spatial
    size, so their outputs can be added element-wise.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by every branch.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names (including the spelling of `squre`) are part of
        # the state_dict keys and are therefore left as they are.
        self.squre = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, padding=1, stride=1)
        self.cross_ver = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=(1, 3), padding=(0, 1), stride=1)
        self.cross_hor = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=(3, 1), padding=(1, 0), stride=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.ReLU = nn.ReLU(True)

    def forward(self, x):
        """Return ReLU(BN(square(x) + cross_ver(x) + cross_hor(x)))."""
        fused = self.squre(x) + self.cross_ver(x) + self.cross_hor(x)
        normalised = self.bn(fused)
        return self.ReLU(normalised)

class EncoderBottleneck(nn.Module):
    """Residual bottleneck (1x1 -> 3x3 -> 1x1) with a projected shortcut.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        stride: Spatial stride applied by both the 3x3 convolution and the
            1x1 shortcut projection, so the two branches always have
            matching spatial sizes.
        base_width: Scales the inner width as
            ``int(out_channels * (base_width / 64))``.
    """

    def __init__(self, in_channels, out_channels, stride=1, base_width=64):
        super().__init__()

        # Shortcut branch: 1x1 projection to match channels and stride.
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        width = int(out_channels * (base_width / 64))

        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False)
        self.norm1 = nn.BatchNorm2d(width)

        # The stride used to be hard-coded to 2 here, which only matched the
        # shortcut when callers passed stride=2; with the default stride=1
        # the residual addition failed on mismatched spatial sizes.
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, groups=1, padding=1, dilation=1, bias=False)
        self.norm2 = nn.BatchNorm2d(width)

        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1, bias=False)
        self.norm3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ReLU(main_branch(x) + shortcut(x))."""
        # Compute the shortcut first, from the untouched input.
        x_down = self.downsample(x)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.norm3(x)
        x = x + x_down
        x = self.relu(x)

        return x


class DecoderBottleneck(nn.Module):
    """Upsample, optionally concatenate a skip feature, then two 3x3 convs.

    Args:
        in_channels: Channels entering the convolution stack, i.e. after
            the optional concatenation with the skip feature.
        out_channels: Channels produced by both convolutions.
        scale_factor: Bilinear upsampling factor (default 2).
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)

        # Two Conv-BN-ReLU stages; the first one changes the channel count.
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.layer = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x, x_concat=None):
        """Upsample ``x`` and refine it.

        Args:
            x: Feature map to upsample.
            x_concat: Optional skip feature; when given it is placed
                *before* the upsampled map along the channel axis.

        Returns:
            The output of the convolution stack.
        """
        upsampled = self.upsample(x)

        if x_concat is None:
            merged = upsampled
        else:
            merged = torch.cat([x_concat, upsampled], dim=1)

        return self.layer(merged)


class Encoder(nn.Module):
    """Convolutional stem + ViT encoder returning a bottleneck and two skips.

    Args:
        img_dim: Side length of the (square) input image.
        in_channels: Channels of the input image.
        out_channels: Base channel width; the stages produce
            ``out_channels``, ``out_channels * 2`` and ``out_channels * 4``
            channels.
        head_num: Attention heads per transformer block.
        mlp_dim: Hidden width of the transformer feed-forward networks.
        block_num: Number of transformer blocks.
        patch_dim: Only used to derive ``vit_img_dim = img_dim // patch_dim``,
            which must equal the side length of the feature map that
            reaches the ViT (the input is downsampled by three stride-2
            stages before that point).
    """

    def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2)
        self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2)

        self.vit_img_dim = img_dim // patch_dim
        # The ViT works on the feature map one pixel per token (patch_dim=1).
        self.vit = ViT(self.vit_img_dim, out_channels * 4, out_channels * 4,
                       head_num, mlp_dim, block_num, patch_dim=1, classification=False)

        # Bottleneck width is out_channels * 2 so that, once concatenated
        # with the out_channels * 2 skip feature, it gives the
        # out_channels * 4 channels the first decoder stage is built for.
        # This used to be a literal 256, which is the same value for
        # out_channels == 128 but mismatched for every other width.
        bottleneck_channels = out_channels * 2
        self.conv2 = nn.Conv2d(out_channels * 4, bottleneck_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.BatchNorm2d(bottleneck_channels)

    def forward(self, x):
        """Encode an image batch.

        Returns:
            Tuple ``(x, x1, x2)``: the bottleneck feature map, the output
            of the stem (after ReLU) and the output of ``encoder1``.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x1 = self.relu(x)

        x2 = self.encoder1(x1)
        x = self.encoder2(x2)

        x = self.vit(x)
        # Fold the token sequence back into a square feature map.
        x = rearrange(x, "b (x y) c -> b c x y", x=self.vit_img_dim, y=self.vit_img_dim)

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)

        return x, x1, x2


class Decoder(nn.Module):
    """Three upsampling stages; the first two consume skip features.

    Args:
        out_channels: Base channel width shared with the encoder.
        class_num: Accepted but not used by this class.
    """

    def __init__(self, out_channels, class_num):
        super().__init__()

        half_width = out_channels // 2
        eighth_width = out_channels // 8

        self.decoder2 = DecoderBottleneck(out_channels * 4, out_channels)
        self.decoder3 = DecoderBottleneck(out_channels * 2, half_width)
        self.decoder4 = DecoderBottleneck(half_width, eighth_width)

    def forward(self, x, x1, x2):
        """Decode ``x`` using ``x2`` then ``x1`` as skip connections.

        The last stage upsamples without a skip feature.
        """
        out = self.decoder2(x, x2)
        out = self.decoder3(out, x1)
        return self.decoder4(out)


class TransUNet(nn.Module):
    """Encoder/decoder network: ``Encoder`` followed by ``Decoder``.

    Args:
        img_dim: Side length of the (square) input image.
        in_channels: Channels of the input image.
        out_channels: Base channel width of encoder and decoder.
        head_num: Attention heads per transformer block.
        mlp_dim: Hidden width of the transformer feed-forward networks.
        block_num: Number of transformer blocks.
        patch_dim: Forwarded to the encoder.
        class_num: Forwarded to the decoder.
    """

    def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num):
        super().__init__()

        self.encoder = Encoder(
            img_dim, in_channels, out_channels,
            head_num, mlp_dim, block_num, patch_dim,
        )
        self.decoder = Decoder(out_channels, class_num)
        # Optional ACBlock refinement of the two skip features is disabled:
        # self.block1 = ACBlock(out_channels, out_channels)
        # self.block2 = ACBlock(2 * out_channels, 2 * out_channels)

    def forward(self, x):
        """Return the decoder output for the image batch ``x``."""
        bottleneck, skip1, skip2 = self.encoder(x)
        # skip1 = self.block1(skip1)
        # skip2 = self.block2(skip2)
        return self.decoder(bottleneck, skip1, skip2)


if __name__ == '__main__':
    import torch

    # Smoke test: build the network, report its parameter count and the
    # output shape for one random 13-channel 48x48 input.
    transunet_config = dict(img_dim=48,
                            in_channels=13,
                            out_channels=128,
                            head_num=4,
                            mlp_dim=512,
                            block_num=6,
                            patch_dim=8,
                            class_num=18)
    transunet = TransUNet(**transunet_config)

    parameter_total = sum(p.numel() for p in transunet.parameters())
    print(parameter_total)

    sample = torch.randn(1, 13, 48, 48)
    print(transunet(sample).shape)


15478368
torch.Size([1, 16, 48, 48])


In [None]:
import sys
sys.path.append("/content/drive/MyDrive/Colabdocument/MTLCC-pytorch-master/src")

import torch
from utils.dataset import ijgiDataset as Dataset
from utils.snapshot import resume
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import numpy as np

could not find visdom package. try 'pip install visdom'. continue without...


In [None]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F


def attn(query, key, value):
    """Scaled dot-product attention on channel-first sequences.

    Args:
        query: Tensor of shape (N, C, S).
        key: Tensor of shape (N, C, S).
        value: Tensor of shape (N, C, S).

    Returns:
        Tensor of shape (N, C, S): for every query position, the
        softmax-weighted sum of the value vectors.
    """
    scale = math.sqrt(query.size(1))
    logits = torch.matmul(query.transpose(1, 2), key) / scale  # (N, S, S)
    weights = F.softmax(logits, dim=-1)
    mixed = torch.matmul(weights, value.transpose(1, 2))  # (N, S, C)
    return mixed.transpose(1, 2)  # (N, C, S)


class SAAttnMem(nn.Module):
    """Attention over a hidden state and a memory, with a gated memory update.

    Args:
        input_dim: Channels of both the hidden state ``h`` and memory ``m``.
        d_model: Channels of the query/key/value projections.
        kernel_size: 2-tuple kernel size of the output convolution; it is
            padded with half the kernel size on each axis.
    """

    def __init__(self, input_dim, d_model, kernel_size):
        super().__init__()
        pad = kernel_size[0] // 2, kernel_size[1] // 2
        self.d_model = d_model
        self.input_dim = input_dim
        # h -> (query, key, value); m -> (key, value).
        self.conv_h = nn.Conv2d(input_dim, d_model * 3, kernel_size=1)
        self.conv_m = nn.Conv2d(input_dim, d_model * 2, kernel_size=1)
        # Fuses the two attention results back to d_model channels.
        self.conv_z = nn.Conv2d(d_model * 2, d_model, kernel_size=1)
        # Produces the three gate pre-activations (i, g, o).
        self.conv_output = nn.Conv2d(input_dim + d_model, input_dim * 3,
                                     kernel_size=kernel_size, padding=pad)

    def forward(self, h, m):
        """Return the updated ``(h_next, m_next)`` pair.

        Args:
            h: Hidden state, (N, input_dim, H, W).
            m: Memory state with the same shape as ``h``.
        """
        hq, hk, hv = torch.split(self.conv_h(h), self.d_model, dim=1)
        mk, mv = torch.split(self.conv_m(m), self.d_model, dim=1)
        N, C, H, W = hq.size()

        # The same queries attend to the hidden state and to the memory.
        queries = hq.view(N, C, -1)
        Zh = attn(queries, hk.view(N, C, -1), hv.view(N, C, -1))
        Zm = attn(hq.view(N, C, -1), mk.view(N, C, -1), mv.view(N, C, -1))
        stacked = torch.cat([Zh.view(N, C, H, W), Zm.view(N, C, H, W)], dim=1)
        Z = self.conv_z(stacked)

        gates = self.conv_output(torch.cat([Z, h], dim=1))
        i, g, o = torch.split(gates, self.input_dim, dim=1)
        i = torch.sigmoid(i)
        g = torch.tanh(g)

        # Gate i blends the candidate g with the previous memory.
        m_next = i * g + (1 - i) * m
        h_next = torch.sigmoid(o) * m_next
        return h_next, m_next


class SAConvLSTMCell(nn.Module):
    """ConvLSTM cell whose hidden state is refined by an ``SAAttnMem`` module.

    Args:
        input_dim: Channels of the per-step input.
        hidden_dim: Channels of the hidden, cell and memory states.
        d_attn: Projection width used inside the attention module.
        kernel_size: 2-tuple kernel size for the gate convolution and the
            attention module's output convolution.
    """

    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        pad = kernel_size[0] // 2, kernel_size[1] // 2

        # One convolution yields all four gate pre-activations.
        self.conv = nn.Conv2d(in_channels=input_dim + hidden_dim,
                              out_channels=4 * hidden_dim,
                              kernel_size=kernel_size,
                              padding=pad)
        self.sa = SAAttnMem(input_dim=hidden_dim, d_model=d_attn, kernel_size=kernel_size)

    def forward(self, inputs, cell_state, hidden_state, memory_state):
        """Advance the cell by one time step.

        Returns:
            Tuple ``(cell_state, hidden_state, memory_state)`` after the
            update.
        """
        stacked = torch.cat([inputs, hidden_state], dim=1)
        gate_maps = torch.split(self.conv(stacked), self.hidden_dim, dim=1)
        cc_i, cc_f, cc_o, cc_g = gate_maps

        input_gate = torch.sigmoid(cc_i)
        forget_gate = torch.sigmoid(cc_f)
        output_gate = torch.sigmoid(cc_o)
        candidate = torch.tanh(cc_g)

        cell_state = forget_gate * cell_state + input_gate * candidate
        hidden_state = output_gate * torch.tanh(cell_state)

        # Refine the hidden state against the long-term memory.
        hidden_state, memory_state = self.sa(hidden_state, memory_state)
        return cell_state, hidden_state, memory_state


In [None]:
class LSTMSequentialEncoder(torch.nn.Module):
    """Run an ``SAConvLSTMCell`` over a frame sequence and classify per pixel.

    Args:
        height: Accepted but not used.
        width: Accepted but not used.
        input_dim: Channels of each input frame; also used as the cell's
            hidden and attention width.
        hidden_dim: Only used to size ``inconv``.
        nclasses: Number of output classes.
        kernel_size: 2-tuple kernel size forwarded to the cell.
        bias: Accepted but not used.
    """

    def __init__(self, height, width, input_dim=13, hidden_dim=64, nclasses=8, kernel_size=(3,3), bias=False):
        super().__init__()

        # NOTE: `inconv` is never called in forward(); it is kept so that
        # the module's parameters / state_dict keys do not change.
        self.inconv = torch.nn.Conv3d(input_dim,hidden_dim,(1,3,3))
        self.cell = SAConvLSTMCell(input_dim,input_dim,input_dim,kernel_size)
        self.final = torch.nn.Conv2d(input_dim, nclasses, (1, 1))

    def forward(self, x1, hidden=None, state=None):
        """Encode a sequence and return per-pixel class log-probabilities.

        Args:
            x1: Tensor of shape (b, t, c, h, w).
            hidden: Accepted but not used.
            state: Accepted but not used.

        Returns:
            ``log_softmax`` over the class axis of the 1x1 convolution
            applied to the final cell state.
        """
        # (b x t x c x h x w) -> (b x c x t x h x w)
        x1 = x1.permute(0,2,1,3,4)

        b, c, t, h, w = x1.shape

        # Allocate the initial states on the input's device (and with its
        # dtype) rather than unconditionally on CUDA, so the module also
        # runs on CPU and on whichever GPU holds the input.
        hidden1 = x1.new_zeros(b, c, h, w)
        cell1 = x1.new_zeros(b, c, h, w)
        memory1 = x1.new_zeros(b, c, h, w)

        for step in range(t):
            cell1, hidden1, memory1 = self.cell(x1[:, :, step, :, :], cell1, hidden1, memory1)

        # The classifier reads the final *cell* state.
        x = self.final(cell1)

        return F.log_softmax(x, dim=1)

In [None]:
class Model(nn.Module):
    """TransUNet applied to every frame, followed by a ConvLSTM classifier.

    NOTE: every constructor argument is currently ignored. Both
    sub-networks are built with the fixed settings written out below
    (48x48 inputs with 13 channels, 18 classes). The signature is kept
    unchanged for callers.
    """

    def __init__(self,img_dim=48,in_channels=3,out_channels=64,head_num=4,mlp_dim=512,block_num=6,patch_dim=16,class_num=18,height=48, width=48, input_dim=18, hidden_dim=64, nclasses=8, kernel_size=(3,3), bias=False):
        super().__init__()
        self.transunet = TransUNet(img_dim=48,
                          in_channels=13,
                          out_channels=128,
                          head_num=4,
                          mlp_dim=512,
                          block_num=6,
                          patch_dim=8,
                          class_num=18)

        # The recurrent part stays wrapped in DataParallel so parameter
        # names (`lstm.module.*`) match existing checkpoints. It is moved
        # to the GPU only when one is available, so the model can also be
        # constructed on a CPU-only machine.
        self.lstm = torch.nn.DataParallel(
            LSTMSequentialEncoder(48,48,input_dim=16,nclasses=18))
        if torch.cuda.is_available():
            self.lstm = self.lstm.cuda()

    def forward(self,input,target):
        """Return ``(prediction, target)``.

        Args:
            input: Tensor indexed as (batch, time, channels, height, width).
            target: Label tensor; moved to the GPU when one is available
                and returned alongside the prediction.
        """
        # Encode each time step separately; the number of steps is taken
        # from the input instead of being fixed to 30.
        frames = [self.transunet(input[:,i,:,:,:]).unsqueeze(1)
                  for i in range(input.shape[1])]
        input1 = torch.cat(frames,dim=1)

        if torch.cuda.is_available():
            input1 = input1.cuda()
            target = target.cuda()

        return self.lstm(input1),target


In [None]:
import torch.nn
from utils.dataset import ijgiDataset as Dataset
from utils.logger import Logger, Printer
import argparse
from utils.snapshot import save, resume
import os

def parse_args():
    """Parse the training command line and return the argparse namespace.

    One positional argument (``data``) plus optional batch size, worker
    count, epoch count, learning rate, snapshot path and checkpoint
    directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("data", type=str,help="path to dataset")

    # (short flag, long flag, default, type, help text)
    optional_arguments = (
        ('-b', "--batchsize", 1, int, "batch size"),
        ('-w', "--workers", 0, int, "number of dataset worker threads"),
        ('-e', "--epochs", 10, int, "epochs to train"),
        ('-l', "--learning_rate", 1e-3, float, "learning rate"),
        ('-s', "--snapshot", None, str, "load weights from snapshot"),
        ('-c', "--checkpoint_dir", None, str, "directory to save checkpoints"),
    )
    for short_flag, long_flag, default, value_type, description in optional_arguments:
        cli.add_argument(short_flag, long_flag, default=default,
                         type=value_type, help=description)

    return cli.parse_args()

def main(
    datadir,
    batchsize = 16,
    workers = 0,
    epochs = 10,
    lr = 1e-3,
    snapshot = None,
    checkpoint_dir = None
    ):
    """Train and evaluate ``Model`` on the fold-0 train/test tile split.

    Args:
        datadir: Dataset root passed to ``Dataset``.
        batchsize: Batch size for both data loaders.
        workers: Number of data-loader workers.
        epochs: Epoch index at which training stops (exclusive).
        lr: Adam learning rate.
        snapshot: Optional snapshot to resume from.
        checkpoint_dir: When given, a checkpoint is saved there after
            every epoch.
    """
    traindataset = Dataset(datadir, tileids="tileids/train_fold0.tileids")
    testdataset = Dataset(datadir, tileids="tileids/test_fold0.tileids")

    # Read but not used further down.
    nclasses = len(traindataset.classes)

    loader_options = dict(batch_size=batchsize, num_workers=workers)
    traindataloader = torch.utils.data.DataLoader(traindataset, shuffle=True, **loader_options)
    testdataloader = torch.utils.data.DataLoader(testdataset, shuffle=False, **loader_options)

    logger = Logger(columns=["loss"], modes=["train", "test"])

    network = Model()

    optimizer = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=1e-3)
    loss = torch.nn.NLLLoss()

    if torch.cuda.is_available():
        network = torch.nn.DataParallel(network).cuda()
        loss = loss.cuda()

    # Resume epoch counter and logged data from a snapshot, if any.
    start_epoch = 0
    if snapshot is not None:
        state = resume(snapshot, model=network, optimizer=optimizer)

        if "epoch" in state.keys():
            start_epoch = state["epoch"] + 1

        if "data" in state.keys():
            logger.resume(state["data"])

    for epoch in range(start_epoch, epochs):
        logger.update_epoch(epoch)

        print("\nEpoch {}".format(epoch))

        print("train")
        train_epoch(traindataloader, network, optimizer, loss, logger)

        print("\ntest")
        test_epoch(testdataloader, network, loss, logger)

        data = logger.get_data()

        if checkpoint_dir is not None:
            checkpoint_name = os.path.join(checkpoint_dir, "model_{:02d}.pth".format(epoch))
            save(checkpoint_name, network, optimizer, epoch=epoch, data=data)

        # The CSV log is rewritten after every epoch.
        logger.save_csv('/content/TransUNet+SAConvLSTM/epoch_data.csv')

def train_epoch(dataloader, network, optimizer, loss, logger, accumulation_steps=4):
    """Train for one epoch with gradient accumulation.

    Gradients are accumulated over ``accumulation_steps`` consecutive
    batches before each optimizer step. The per-batch losses are
    back-propagated unscaled (so the accumulated gradient is their sum),
    while the logged value is their mean.

    Iterations left over when ``len(dataloader)`` is not a multiple of
    ``accumulation_steps`` still run forward/backward but never trigger an
    optimizer step or a log entry.

    Args:
        dataloader: Iterable yielding ``(input, target)`` pairs.
        network: Module called as ``network(input, target)`` and returning
            ``(output, target)``.
        optimizer: Optimizer over the network's parameters.
        loss: Callable ``loss(output, target)`` returning a scalar tensor.
        logger: Object providing ``set_mode`` and ``log``.
        accumulation_steps: Number of batches per optimizer step
            (default 4, the previously hard-coded value).

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")

    network.train()
    printer = Printer(N=len(dataloader))
    logger.set_mode("train")

    for iteration, data in enumerate(dataloader):
        # Start of a new accumulation window.
        if iteration % accumulation_steps == 0:
            optimizer.zero_grad()
            batch_loss = 0

        inputs, target = data
        # Call the module itself (not .forward) so registered hooks run.
        output, target = network(inputs, target)
        step_loss = loss(output, target)
        # Detach so the running value used for logging does not keep a
        # reference to the autograd graph.
        batch_loss += step_loss.detach() / accumulation_steps
        step_loss.backward()

        # End of the window: apply the accumulated gradients and log.
        if (iteration + 1) % accumulation_steps == 0:
            stats = {"loss": batch_loss.cpu().numpy()}
            optimizer.step()

            printer.print(stats, iteration)
            logger.log(stats, (iteration + 1) / accumulation_steps)

def test_epoch(dataloader, network, loss, logger):
    """Evaluate for one epoch without gradients.

    The loss is averaged over windows of four consecutive batches and one
    entry is printed/logged per completed window; a trailing incomplete
    window is not reported.
    """
    network.eval()
    printer = Printer(N=len(dataloader))
    logger.set_mode("test")

    with torch.no_grad():
        for iteration, (inputs, target) in enumerate(dataloader):
            position_in_window = iteration % 4

            if position_in_window == 0:
                batch_loss = 0

            output, target = network.forward(inputs, target)
            batch_loss = batch_loss + loss(output, target) / 4

            if position_in_window != 3:
                continue

            # Fourth batch of the window: report the averaged loss.
            stats = {"loss": batch_loss.data.cpu().numpy()}
            printer.print(stats, iteration)
            logger.log(stats, (iteration + 1) / 4)
            

if __name__ == "__main__":
    # Fixed training configuration; parse_args() is not consulted here.
    run_settings = dict(
        batchsize=4,
        workers=8,
        epochs=100,
        lr=1e-3,
        snapshot=None,
        checkpoint_dir='/content/TransUNet+SAConvLSTM',
    )
    main('data', **run_settings)




rejected_nopath:2562, rejected_length:100, total_samples:3872


  cpuset_checked))


rejected_nopath:784, rejected_length:19, total_samples:1213

Epoch 0
train
iteration: 7/968, logs/sec: 4.72, loss: 2.93

KeyboardInterrupt: ignored