# 云分类器

In [None]:
import pickle
import os

In [None]:
import torch
from torchvision import transforms, datasets, models
import torch.nn as nn
import torch.nn.functional as F

## 载入数据

In [None]:
# Root folder of the training images; passed to datasets.ImageFolder.
train_data_path = r"./data"
# Folder intended for test images; not referenced anywhere in the visible code.
test_data_path = r"./test"

In [None]:
# Number of images per mini-batch handed to the DataLoader.
batch_size = 8

In [None]:
# Augmentation pipeline applied to every training image, in the listed order.
# The final step converts the image to a tensor for the model.
data_transform = transforms.Compose([
    # Random affine transform with the degrees argument set to 90.
    transforms.RandomAffine(90),
    # Randomly converts some images to grayscale (library-default probability).
    transforms.RandomGrayscale(),
    # Random crop resized to size 256; MyNet's hard-coded fc width (64 * num_class)
    # relies on this size together with its five 2x down-sampling blocks.
    transforms.RandomResizedCrop(256),
    # NOTE(review): called without arguments; torchvision's jitter ranges
    # presumably default to zero, which would make this step a no-op — confirm.
    transforms.ColorJitter(),
    transforms.ToTensor()
])

In [None]:
# Training dataset read from train_data_path with the augmentation pipeline applied.
# Presumably one sub-folder per class (ImageFolder convention) — verify the layout.
cloud_datasets = datasets.ImageFolder(train_data_path,transform=data_transform)

In [None]:
# Shuffled mini-batch loader over the training dataset, using 4 worker processes.
image_loader = torch.utils.data.DataLoader(cloud_datasets,batch_size=batch_size,shuffle=True,num_workers=4)

## 构造模型

In [None]:
# Channel count of the network input (used as in_channels of MyNet.first_layer).
INPUT_PLANES = 3
# Channel count produced by MyNet.first_layer and fed to the first BaseBlock.
FIRST_LAYER_PLANES = 10

In [None]:
class BaseBlock(nn.Module):
    """Residual convolution block that halves the spatial resolution.

    The main branch is two 3x3 convolutions, each followed by ReLU and
    dropout. The shortcut branch is a 1x1 convolution that projects the
    input to ``out_planes`` channels. Both branches are summed, instance
    normalised, passed through dropout and finally 2x2 average pooled.

    Args:
        in_planes: number of channels of the incoming feature map.
        out_planes: number of channels produced by the block.
    """

    def __init__(self, in_planes, out_planes):
        super(BaseBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, 3, padding=1)
        self.conv2 = nn.Conv2d(out_planes, out_planes, 3, padding=1)
        self.shortcut = nn.Conv2d(in_planes, out_planes, 1)

    def forward(self, x_in):
        """Run the block on ``x_in`` and return the pooled feature map.

        Dropout is only active while the module is in training mode.
        """
        # F.dropout defaults to training=True, which would keep dropping
        # activations after model.eval(); tie it to the module's mode instead.
        active = self.training

        x1 = self.conv1(x_in)
        x1 = F.relu(x1, inplace=True)
        x1 = F.dropout(x1, training=active)
        x1 = self.conv2(x1)
        x1 = F.relu(x1, inplace=True)
        x1 = F.dropout(x1, training=active)
        x2 = self.shortcut(x_in)

        x_out = x1 + x2
        x_out = F.instance_norm(x_out)
        x_out = F.dropout(x_out, training=active)
        x_out = F.avg_pool2d(x_out, 2)

        return x_out
        

In [None]:
class MyNet(nn.Module):
    """Image classifier built from a stack of BaseBlock stages.

    A 1x1 convolution lifts the input from INPUT_PLANES to
    FIRST_LAYER_PLANES channels, one BaseBlock per entry of ``planes``
    follows, a 1x1 convolution maps to ``num_class`` channels, and a linear
    layer turns the flattened maps into ``num_class`` scores.

    Args:
        planes: sequence of output channel counts, one per BaseBlock.
        num_class: number of output classes.

    NOTE(review): the linear layer expects 64 values per class map, i.e. an
    8x8 map after all blocks. That matches 256x256 inputs with five
    down-sampling blocks; other configurations will fail in ``fc`` — confirm
    this coupling is intended.
    """

    def __init__(self, planes, num_class):
        super().__init__()
        self.layers = nn.ModuleList()
        self.first_layer = nn.Conv2d(
            INPUT_PLANES, FIRST_LAYER_PLANES, kernel_size=1, padding=0
        )
        self.out_layer = nn.Conv2d(planes[-1], num_class, kernel_size=1)
        self.fc = nn.Linear(64 * num_class, num_class)

        # Each block takes the previous block's channel count as its input;
        # the first block follows first_layer.
        block_inputs = [FIRST_LAYER_PLANES] + list(planes[:-1])
        for n_in, n_out in zip(block_inputs, planes):
            self.layers.append(BaseBlock(n_in, n_out))

    def forward(self, x):
        """Return the per-class scores for the image batch ``x``."""
        features = self.first_layer(x)
        for block in self.layers:
            features = block(features)
        class_maps = self.out_layer(features)
        # Flatten everything except the batch dimension for the linear layer.
        flat = class_maps.view(class_maps.size(0), -1)
        return self.fc(flat)
        

In [None]:
# Output channel count of each BaseBlock stage in MyNet (five stages).
planes = [64,64,128,256,256]

In [None]:
# The network to train: one BaseBlock per entry of planes, 10 output classes.
classifier = MyNet(planes,10)

In [None]:
# Directory where the model checkpoint is loaded from and saved to.
model_path = r"./model"

In [None]:
# File name of the pickled state dict inside model_path.
model_name = r"model.pickle"

In [None]:
# Best-effort resume: load previously saved weights when a checkpoint exists,
# otherwise keep the freshly initialised classifier.
# SECURITY NOTE(review): pickle.load can execute arbitrary code; only load
# checkpoint files from a trusted source.
try:
    with open(os.path.join(model_path, model_name), "rb") as f:
        state_dict = pickle.load(f)
    classifier.load_state_dict(state_dict)
except FileNotFoundError:
    # Expected on the first run, before any checkpoint has been written.
    print("can't load model")
except Exception as err:
    # Corrupt file, mismatching architecture, etc.: continue with fresh
    # weights but report the reason. KeyboardInterrupt/SystemExit are no
    # longer swallowed.
    print("can't load model: %r" % (err,))

## 训练模型

### 选择loss和优化器

In [None]:
# Cross-entropy criterion applied to the classifier outputs and integer labels.
loss = torch.nn.CrossEntropyLoss()
# Adamax optimiser over all classifier parameters, with default hyper-parameters.
optimizer = torch.optim.Adamax(classifier.parameters())

### 定义训练流程

In [None]:
def train(model, loss, optimizer, data_loader, epochs=1, cuda=True):
    """Train ``model`` and print the loss and accuracy after every epoch.

    Args:
        model: network to optimise; it is switched to training mode.
        loss: callable ``loss(outputs, labels)`` returning a scalar tensor.
        optimizer: optimiser already bound to ``model``'s parameters.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        epochs: number of passes over ``data_loader``.
        cuda: when true, move the model and every batch to ``cuda:0`` if it
            is available (falling back to the CPU); when false, tensors are
            left on whatever device they already live on.
    """
    device = None
    if cuda:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model.to(device)
    # Training mode is required regardless of the device choice.
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        running_corrects = 0
        data_sizes = 0

        for inputs, labels in data_loader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            preds = torch.argmax(outputs, dim=1)

            loss_result = loss(outputs, labels)
            loss_result.backward()
            optimizer.step()

            running_loss += loss_result.item()
            running_corrects += torch.sum(preds == labels).item()
            data_sizes += len(labels)

        # Guard against an empty loader so the summary never divides by zero.
        denom = max(data_sizes, 1)
        # NOTE(review): the per-batch loss values are summed and divided by
        # the number of samples; with a mean-reduced criterion this reports
        # roughly loss / batch_size — confirm this is the intended metric.
        print('[%d, %5d] loss: %.6f, acc: %.6f' %
              (epoch + 1, data_sizes, running_loss / denom, running_corrects / denom))

### 训练模型

In [None]:
# Run 100 epochs over the training loader (cuda left at train()'s default).
train(classifier,loss, optimizer, image_loader, epochs=100)

### 保存模型

In [None]:
# Persist the trained weights as a pickled, CPU-resident state dict so the
# checkpoint can be reloaded on machines without a GPU.
# Create the target directory first: otherwise a missing ./model folder would
# raise after the whole training run and the weights would be lost.
os.makedirs(model_path, exist_ok=True)
device = torch.device("cpu")
classifier.to(device)
# Build the state dict before opening the file so that a failure here does
# not leave a truncated checkpoint behind.
state_dict = classifier.state_dict()
with open(os.path.join(model_path, model_name), "wb") as f:
    pickle.dump(state_dict, f)