<a href="https://colab.research.google.com/github/yun-xiaoxiong/yun-xiaoxiong/blob/main/Unet_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import os
import glob
import random
import matplotlib.pyplot as plt
# Release unoccupied cached GPU memory, which matters when this notebook is re-run
# inside a live kernel that already allocated CUDA tensors.
torch.cuda.empty_cache()
# Seed every RNG the notebook touches (random data, weight init) so runs are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Training hyper-parameters and the device the model/tensors should be placed on.
learning_rate = 1e-6  # Adam step size
batch_size = 1        # NOTE(review): not read by any other cell in this notebook
epoch = 100           # number of optimisation steps train() performs
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Unit(nn.Module):
    """Two stacked Conv2d -> BatchNorm2d -> ReLU stages.

    Args:
        ch_in: channels of the incoming feature map.
        ch_out: channels produced by both convolutions.
        kernel_size, stride, padding: passed unchanged to both Conv2d layers.
    """

    def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        # Attribute name and layer order determine the state_dict keys (`an_unit.<i>.*`).
        self.an_unit = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, **conv_kwargs),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(),
            nn.Conv2d(ch_out, ch_out, **conv_kwargs),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run both conv stages; with the default 3/1/1 settings H and W are preserved."""
        return self.an_unit(x)

class Unet(nn.Module):
    """U-Net producing per-pixel class logits.

    Encoder ``cr1``..``cr5`` with four MaxPool2d(2) downsamplings, followed by four decoder
    stages that upsample (nearest neighbour) and concatenate the matching encoder output.

    Args:
        in_channels: channels of the input image (default 1, the original behaviour).
        num_classes: channels of the output logits (default 2, the original behaviour).

    Input ``(N, in_channels, H, W)`` -> output ``(N, num_classes, H, W)`` raw logits.
    H and W are no longer required to be 512: every upsample targets the exact spatial
    size of its skip connection instead of a hard-coded 64/128/256/512. They must still
    be at least 16 so the fourth pooling leaves a non-empty feature map.
    """

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        self.pool = nn.MaxPool2d(2, stride=2)
        # Encoder: each stage doubles the channel count; pooling is applied in forward().
        self.cr1 = Unit(in_channels, 64)
        self.cr2 = Unit(64, 128)
        self.cr3 = Unit(128, 256)
        self.cr4 = Unit(256, 512)
        self.cr5 = Unit(512, 1024)

        # Decoder. NOTE(review): forward() applies each up* block twice - once to halve
        # the channels of the upsampled map and once to fuse the concatenation with the
        # skip - so those two stages share weights. Kept so existing checkpoints still
        # load; a textbook U-Net would use separate modules for each stage.
        self.up1 = Unit(1024, 512)
        self.up2 = Unit(512, 256)
        self.up3 = Unit(256, 128)
        self.up4 = Unit(128, 64)

        # Output head: 3x3 conv + ReLU, then a 1x1 conv down to one logit per class.
        self.conv = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)
        )

    @staticmethod
    def _upsample_to(x, ref):
        """Nearest-neighbour resize of ``x`` to the spatial size (H, W) of ``ref``."""
        return nn.functional.interpolate(x, size=ref.shape[-2:])

    def forward(self, x):
        """Return logits of shape (N, num_classes, H, W) for input (N, in_channels, H, W)."""
        x1 = self.cr1(x)
        x2 = self.cr2(self.pool(x1))
        x3 = self.cr3(self.pool(x2))
        x4 = self.cr4(self.pool(x3))
        x5 = self.cr5(self.pool(x4))
        # Each decoder stage: upsample to the skip's size, halve the channels,
        # concatenate with the skip, then fuse back down to the skip's channel count.
        x5_up = self.up1(self._upsample_to(x5, x4))
        x6 = self.up1(torch.cat([x4, x5_up], 1))
        x6_up = self.up2(self._upsample_to(x6, x3))
        x7 = self.up2(torch.cat([x3, x6_up], 1))
        x7_up = self.up3(self._upsample_to(x7, x2))
        x8 = self.up3(torch.cat([x2, x7_up], 1))
        x8_up = self.up4(self._upsample_to(x8, x1))
        x9 = self.up4(torch.cat([x1, x8_up], 1))
        return self.conv(x9)

In [None]:
# Smoke-test data: a batch of 2 random single-channel 512x512 images in which every
# pixel is labelled class 1. Inputs and labels are plain tensors; they need no gradients.
x = torch.randn(2, 512, 512)
y = torch.ones(2, 512, 512)
# `.to(device)` instead of `.cuda()` so the notebook also runs on a CPU-only runtime.
model = Unet().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

def train():
    """Optimise `model` on the fixed (x, y) batch for `epoch` steps.

    After every step the loss is printed and test() is called to report IoU.
    """
    for ep in range(epoch):
        # test() switches the model to eval mode, so re-enable training mode every
        # step; otherwise all steps after the first would run BatchNorm in eval mode.
        model.train()
        feature = x.to(device).unsqueeze(1)  # (N, H, W) -> (N, 1, H, W)
        label = y.to(device).long()          # integer class index per pixel, (N, H, W)
        out = model(feature)                 # logits, (N, 2, H, W)
        # CrossEntropyLoss accepts (N, C, H, W) logits with (N, H, W) targets directly;
        # this equals flattening every pixel into its own sample and averaging.
        loss = loss_fn(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("epoch:", ep + 1, "  loss: ", loss.item())
        test()

RuntimeError: ignored

In [None]:
def test():
    """Evaluate `model` on the fixed (x, y) batch; print and return the class-1 IoU.

    The class-1 softmax probability is thresholded at 0.5 to form a binary mask that is
    compared with the label (assumed to contain 0/1 values - TODO confirm for real data).

    Returns:
        float: intersection over union in [0, 1]; 1.0 when both the prediction and the
        label are empty (nothing to segment and nothing predicted).
    """
    was_training = model.training
    model.eval()
    try:
        # Evaluation needs no autograd graph; building one only wastes (GPU) memory.
        with torch.no_grad():
            feature = x.to(device).unsqueeze(1)  # (N, H, W) -> (N, 1, H, W)
            label = y.to(device)
            out = model(feature)                 # logits, (N, 2, H, W)
            prob_fg = nn.functional.softmax(out, dim=1)[:, 1]  # class-1 probability, (N, H, W)
            pred = (prob_fg >= 0.5).to(label.dtype)
            jiao = (label * pred).sum().item()               # intersection
            bing = (label + pred).clamp(max=1).sum().item()  # union
    finally:
        # Restore the caller's mode so a training loop is not left in eval mode.
        model.train(was_training)
    # An empty union means both masks are empty: treat as a perfect match, not a crash.
    IOU = jiao / bing if bing else 1.0
    print("交集：", jiao, "  并集:", bing)
    accuracy = IOU * 100
    print("IOU accuracy: ", accuracy, "% ")
    print("----------------------------------------------------------------")
    return IOU

In [None]:
train()

RuntimeError: ignored