In [None]:
import os
import random
from random import randint

import cv2
import numpy as np
import torch
import torch.nn as nn
from google.colab.patches import cv2_imshow
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

In [None]:
# Configuration: data location, device, training sizes and the noise schedule.
path='./drive/MyDrive/picture/'  # image folder (Colab Drive mount); later cells build paths as path+name, so keep the trailing '/'
device=torch.device("cuda:0"if torch.cuda.is_available()else "cpu")
batch_size=200
step_size=50000  # NOTE(review): not referenced by any cell in this notebook -- remove or wire up
epoch=100000     # number of passes over the dataloader
beta=0.99        # per-step signal keep factor; its cumulative products (`one`) are used as alpha_bar

# Seed every RNG the notebook draws from (Python's random.randint picks the
# timesteps, torch initialises weights and draws noise) so runs are repeatable
# (up to nondeterministic GPU kernels).
SEED=0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class MyDataset(Dataset):
  """In-memory dataset of images read from one folder, as (3, H, W) float tensors in [0, 1].

  Images are read with cv2.imread, so channels are in OpenCV's BGR order
  (cv2_imshow, used for previews in this notebook, also expects BGR).
  NOTE(review): images are not resized here, while the rest of the notebook
  builds 64x64 tensors -- confirm the folder only holds 64x64 images.

  Args:
    path: folder containing the images; joined to each file name with
      os.path.join, so a trailing '/' is optional.
    max_images: maximum number of images to load, taken in sorted file-name
      order. Defaults to 1, so only the first image is loaded; pass None to
      load every readable image.

  Raises:
    ValueError: if no readable image is found.
  """
  def __init__(self,path,max_images=1):
    self.data=[]
    # sorted() makes "the first image" deterministic; os.listdir order is arbitrary.
    for file in sorted(os.listdir(path)):
      if max_images is not None and len(self.data)>=max_images:
        break
      img=cv2.imread(os.path.join(path,file))
      # cv2.imread returns None for files it cannot decode -- e.g. the
      # u_net.pth checkpoint that this notebook saves into the same folder.
      # Skip them instead of crashing on None / 255.
      if img is None:
        continue
      # (H, W, C) uint8 -> (C, H, W) float32 in [0, 1]
      self.data.append(torch.from_numpy(img/255).float().permute(2,0,1))
    if not self.data:
      raise ValueError(f"no readable images found in {path!r}")
  def __len__(self):
    return len(self.data)
  def __getitem__(self,index):
    return self.data[index]

In [None]:
# Wrap the image folder in a Dataset and serve it in fixed-order batches.
dataset = MyDataset(path)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
def gelu(x):
  """Tanh approximation of GELU: x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).

  Args:
    x: input tensor (any shape); applied element-wise.
  Returns:
    Tensor of the same shape as x.
  """
  inner = x + 0.044715 * x.pow(3)
  gate = 1 + torch.tanh(np.sqrt(2 / np.pi) * inner)
  return gate * x / 2

In [None]:
class layer_down(nn.Module):
  """Encoder stage: two 3x3 convolutions, each followed by GELU.

  padding=1 keeps height and width unchanged; only the channel count goes from
  input_channel to output_channel. No downsampling happens inside this module.
  """
  def __init__(self,input_channel,output_channel):
    super().__init__()
    conv_kwargs = dict(kernel_size=3, padding=1)
    self.first_layer = nn.Conv2d(input_channel, output_channel, **conv_kwargs)
    self.second_layer = nn.Conv2d(output_channel, output_channel, **conv_kwargs)
  def forward(self,X):
    """Apply conv -> GELU -> conv -> GELU to X."""
    hidden = gelu(self.first_layer(X))
    return gelu(self.second_layer(hidden))

In [None]:
class layer_up(nn.Module):
  """Decoder stage: 2x upsampling by transposed convolution, then two 3x3 convolutions.

  The ConvTranspose2d (kernel 2, stride 2) doubles height and width; the two
  padded 3x3 convolutions keep that size. Every convolution is followed by GELU.
  """
  def __init__(self,input_channel,output_channel):
    super().__init__()
    self.first_layer = nn.ConvTranspose2d(input_channel, output_channel, kernel_size=2, stride=2)
    self.second_layer = nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1)
    self.third_layer = nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1)
  def forward(self,X):
    """Upsample X by 2x, then refine: (convT -> GELU) followed by 2 x (conv -> GELU)."""
    out = gelu(self.first_layer(X))
    for conv in (self.second_layer, self.third_layer):
      out = gelu(conv(out))
    return out

In [None]:
# Shape smoke test for the building blocks: a layer_down stage keeps the spatial
# size, a layer_up stage doubles it. These two instances are not used by the model.
now1=layer_down(3,3)
now2=layer_up(3,3)
with torch.no_grad():
  _probe=torch.zeros(1,3,8,8)
  assert now1(_probe).shape==(1,3,8,8)
  assert now2(_probe).shape==(1,3,16,16)
  del _probe

In [None]:
class U_net(nn.Module):
  """Three-level U-Net that predicts the noise contained in a noisy image.

  Encoder: three layer_down stages (4 -> 8 -> 16 -> 32 channels), each followed
  by 2x2 max-pooling. Decoder: three layer_up stages that each upsample 2x; the
  second and third take skip connections (channel concatenation) from the
  encoder. The timestep is mapped by two linear layers to one
  image_size x image_size plane that is appended to the RGB input as a 4th channel.

  Args:
    image_size: side length of the square input images. Must be divisible by 8
      (the encoder pools three times). The default of 64 gives exactly the same
      parameter names and shapes as before, so existing state_dicts still load.
    gelu_output: whether to apply GELU after the very last convolution. GELU is
      bounded below (minimum about -0.17), which prevents the network from
      producing the negative values of the standard-normal noise it is trained
      to predict, so it is off by default. Set True to reproduce the previous
      behaviour, e.g. when evaluating a checkpoint trained with it.

  Raises:
    ValueError: if image_size is not divisible by 8.
  """
  def __init__(self,image_size=64,gelu_output=False):
    super(U_net,self).__init__()
    if image_size%8!=0:
      raise ValueError(f"image_size must be divisible by 8, got {image_size}")
    self.image_size=image_size
    self.gelu_output=gelu_output
    self.one=nn.Linear(1,64)
    self.two=nn.Linear(64,image_size*image_size)
    self.pooling=nn.MaxPool2d((2,2))
    self.down_one_layer=layer_down(4,8)
    self.down_two_layer=layer_down(8,16)
    self.down_three_layer=layer_down(16,32)
    self.up_one_layer=layer_up(32,16)
    self.up_two_layer=layer_up(32,8)
    self.up_three_layer=layer_up(16,3)
  def _last_up(self,X):
    """Run up_three_layer, skipping the GELU after its final convolution unless gelu_output is set."""
    if self.gelu_output:
      return self.up_three_layer(X)
    block=self.up_three_layer
    hidden=gelu(block.second_layer(gelu(block.first_layer(X))))
    return block.third_layer(hidden)
  def forward(self,X,t):
    """Predict the noise in a batch of noisy images.

    Args:
      X: images of shape (B, 3, image_size, image_size).
      t: timestep tensor -- either one value shared by the whole batch
        (shape (1,), as the training cell passes) or one value per sample
        (shape (B,) or (B, 1)).
    Returns:
      Tensor of shape (B, 3, image_size, image_size).
    """
    # (T, 1) -> (T, 64) -> (T, image_size**2) -> (T, 1, image_size, image_size)
    time_plane=self.two(gelu(self.one(t.reshape(-1,1).to(X.dtype))))
    time_plane=time_plane.reshape(-1,1,self.image_size,self.image_size)
    # A single timestep (T == 1) is broadcast over the batch; T == B is used as is.
    time_plane=time_plane.expand(X.size(0),-1,-1,-1)
    one=self.pooling(self.down_one_layer(torch.cat([X,time_plane],dim=1)))  # (B, 8, S/2, S/2)
    two=self.pooling(self.down_two_layer(one))                              # (B, 16, S/4, S/4)
    three=self.pooling(self.down_three_layer(two))                          # (B, 32, S/8, S/8)
    four=self.up_one_layer(three)                                           # (B, 16, S/4, S/4)
    five=self.up_two_layer(torch.cat([four,two],dim=1))                     # (B, 8, S/2, S/2)
    return self._last_up(torch.cat([five,one],dim=1))                       # (B, 3, S, S)

In [None]:
# Build the noise-prediction network and move its weights to the selected device.
model = U_net().to(device)

In [None]:
# Cumulative products of the per-step keep factor: one[k] is beta multiplied
# by itself k + 1 times (1001 entries, k = 0..1000). The training cell uses
# one[t] as alpha_bar for timestep label t.
one = [beta]
for i in range(1000):
  one.append(one[-1] * beta)

In [None]:
# Optimiser and objective: Adam over all U-Net weights, mean-squared error
# between predicted and true noise.
# NOTE(review): 1e-6 is a very small Adam learning rate for a model this size -- worth tuning.
lr = 1e-6
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
loss = torch.nn.MSELoss(reduction='mean')

In [None]:
# Training loop: DDPM-style noise prediction with a periodic sample preview.
# Conventions from the cells above: one[k] is the cumulative product of `beta`
# used as alpha_bar for timestep label k; `beta` is the per-step keep factor.
diffusion_steps=300  # timestep labels are drawn from 1..diffusion_steps (inclusive)
log_every=200        # print the mean training loss every log_every iterations
preview_every=2000   # generate and display one sample every preview_every iterations


def _sample_preview(net,image_shape,num_steps=diffusion_steps):
  """Run the reverse diffusion chain from Gaussian noise and return one image.

  Walks timestep labels num_steps, num_steps - 1, ..., 1. At label t the
  predicted noise is removed using alpha_bar = one[t] -- the same value used
  to noise images labelled t during training -- and fresh Gaussian noise with
  standard deviation sqrt(1 - beta) is added after every step except the last.

  Args:
    net: noise-prediction model, called as net(x, t).
    image_shape: (channels, height, width) of the image to generate.
    num_steps: highest timestep label / number of reverse steps.
  Returns:
    A (height, width, channels) CPU numpy array multiplied by 255, for cv2_imshow.
  """
  x=torch.randn(1,*image_shape,device=device)
  for t in range(num_steps,0,-1):
    eps=net(x,torch.tensor([float(t)],device=device))
    x=(x-(1-beta)/(1-one[t])**0.5*eps)/beta**0.5
    if t>1:
      # Fresh noise on every step; reusing one fixed tensor biases the chain.
      x=x+(1-beta)**0.5*torch.randn_like(x)
  return x[0].permute(1,2,0).cpu().numpy()*255


jishu=0           # global iteration counter
running_loss=0.0  # loss summed since the last log line (named so it does not shadow the builtin sum)
for www in range(epoch):
  for data in dataloader:
    data=data.to(device)
    noise=torch.randn_like(data)
    # NOTE(review): one timestep is shared by the whole batch.
    time=randint(1,diffusion_steps)
    t=torch.tensor([float(time)],device=device)
    alpha_bar=one[time]
    # Noise the clean batch to timestep `time`; the model learns to recover `noise`.
    noisy=alpha_bar**0.5*data+(1-alpha_bar)**0.5*noise
    ls=loss(model(noisy,t),noise)
    running_loss+=ls.item()
    ls.backward()
    torch.nn.utils.clip_grad_value_(model.parameters(),2)
    optimizer.step()
    optimizer.zero_grad()
    jishu+=1
    if jishu%log_every==0:
      print(running_loss/log_every)
      running_loss=0.0
    if jishu%preview_every==0:
      with torch.no_grad():
        cv2_imshow(_sample_preview(model,data.shape[1:]))

In [None]:
# Persist only the learned weights (the state_dict), not the module object.
# NOTE(review): this writes into the same folder the dataset cell lists images from.
torch.save(obj=model.state_dict(), f=path + 'u_net.pth')

In [None]:
# Debug aid: print every gradient tensor that is currently populated.
# NOTE(review): the training loop calls optimizer.zero_grad() after each step, so
# depending on the PyTorch version these grads are either None (skipped) or zeros.
for name, param in model.named_parameters():
    if not (param.requires_grad and param.grad is not None):
        continue
    print(param.grad)