In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import math

## モデルの定義

In [2]:
class CustomModel(nn.Module):
    """Tabular regression network: categorical embeddings + continuous features -> MLP.

    Args:
        emb_size: iterable of ``(num_categories, embedding_dim)`` pairs, one per
            categorical column (in the notebook: brand 40 -> 20, vehicle type 8 -> 4).
        n_cont: number of continuous input features.
        p: dropout probability used after each of the three hidden layers.
        h1, h2, h3: widths of the three hidden layers.
        out_features: width of the output layer.

    Submodule names and registration order are kept stable because a saved
    ``state_dict`` is loaded into this class.
    """

    def __init__(self, emb_size, n_cont, p=0.5, h1=150, h2=100, h3=30, out_features=1):
        super().__init__()
        # One lookup table per categorical column; each compresses that column's
        # category ids into a dense vector of width emb_dim.
        self.embeds = nn.ModuleList(
            nn.Embedding(num_categories, emb_dim) for num_categories, emb_dim in emb_size
        )
        # Three independent dropout modules sharing the same probability p.
        self.dropout1, self.dropout2, self.dropout3 = (nn.Dropout(p) for _ in range(3))
        # Batch-normalises the continuous features before they join the embeddings.
        self.bn_cont = nn.BatchNorm1d(n_cont)

        # First layer sees every embedding vector plus the continuous features.
        in_width = n_cont + sum(emb_dim for _, emb_dim in emb_size)

        # Hidden layer 1
        self.fc1 = nn.Linear(in_width, h1)
        self.bn1 = nn.BatchNorm1d(h1)
        # Hidden layer 2
        self.fc2 = nn.Linear(h1, h2)
        self.bn2 = nn.BatchNorm1d(h2)
        # Hidden layer 3
        self.fc3 = nn.Linear(h2, h3)
        self.bn3 = nn.BatchNorm1d(h3)
        # Output layer
        self.fc4 = nn.Linear(h3, out_features)

    def forward(self, x_cat, x_cont):
        """Run the network.

        Args:
            x_cat: integer tensor; column ``i`` holds ids for ``self.embeds[i]``
                (2-D, presumably ``(batch, n_categorical)``).
            x_cont: float tensor of continuous features, ``(batch, n_cont)``.

        Returns:
            Tensor ``(batch, out_features)``. The final ReLU clamps every
            prediction to be >= 0.
        """
        # Embed each categorical column with its own table.
        embedded = [table(x_cat[:, col]) for col, table in enumerate(self.embeds)]
        hidden = torch.cat(
            [torch.cat(embedded, dim=1), self.bn_cont(x_cont)], dim=1
        )

        # Each hidden stage: Linear -> ReLU -> BatchNorm -> Dropout.
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        for linear, norm, drop in stages:
            hidden = drop(norm(F.relu(linear(hidden))))

        return F.relu(self.fc4(hidden))

In [6]:
# (num_categories, embedding_dim) per categorical input:
# brand 40 -> 20, vehicle type 8 -> 4.
emb_size = [
    (40, 20),
    (8, 4),
]
# Three continuous inputs (power, mileage, registration year — see the GUI callback).
model = CustomModel(
    emb_size,
    n_cont=3,
    p=0.3,
    h1=50,
    h2=50,
    h3=30,
    out_features=1,
)

In [7]:
# Load the trained weights. map_location='cpu' lets a checkpoint saved from a
# GPU load on a CPU-only machine; the GUI below builds its inputs as CPU tensors.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load('GermanyCarModel.pt', map_location='cpu'))
# Inference mode: BatchNorm uses running statistics and Dropout is disabled.
model.eval()

CustomModel(
  (embeds): ModuleList(
    (0): Embedding(40, 20)
    (1): Embedding(8, 4)
  )
  (dropout1): Dropout(p=0.3, inplace=False)
  (dropout2): Dropout(p=0.3, inplace=False)
  (dropout3): Dropout(p=0.3, inplace=False)
  (bn_cont): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=27, out_features=50, bias=True)
  (bn1): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=50, out_features=50, bias=True)
  (bn2): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=50, out_features=30, bias=True)
  (bn3): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc4): Linear(in_features=30, out_features=1, bias=True)
)

In [None]:
from PySide2.QtWidgets import QApplication, QWidget, QLabel, QPushButton,QRadioButton,QButtonGroup, QVBoxLayout,QHBoxLayout,QLineEdit
import sys

class Window(QWidget):
    """Main window: collects car attributes and shows the model's predicted price.

    Relies on the module-level ``model`` (a ``CustomModel`` in eval mode) being
    defined before ``buttoncallback`` is triggered.
    """

    # Brand names in categorical-id order: the radio button id equals the index.
    # NOTE(review): must match the label encoding used when the model was trained.
    BRANDS = (
        "alfa_romeo", "audi", "bmw", "chevrolet", "chrysler",
        "citroen", "dacia", "daewoo", "daihatsu", "fiat",
        "ford", "honda", "hyundai", "jaguar", "jeep",
        "kia", "lada", "lancia", "land_rover", "mazda",
        "mercedes_benz", "mini", "mitsubishi", "nissan", "opel",
        "peugeot", "porsche", "renault", "rover", "saab",
        "seat", "skoda", "smart", "sonstige_autos", "subaru",
        "suzuki", "toyota", "trabant", "volkswagen", "volvo",
    )
    # Brands with index < BRAND_COLUMN_SPLIT go in the first column, the rest in the second.
    BRAND_COLUMN_SPLIT = 21
    # Vehicle-type names in categorical-id order (button id == index).
    VEHICLE_TYPES = (
        "andere", "bus", "cabrio", "coupe",
        "kleinwagen", "kombi", "limousine", "suv",
    )

    # Divisors applied to the raw form values before they are fed to the model.
    # NOTE(review): presumably mirror the scaling used at training time — confirm.
    POWER_SCALE = 100
    MILEAGE_SCALE = 1000
    YEAR_SCALE = 1000
    # The model output is multiplied by this to obtain a price in euros.
    PRICE_SCALE = 100
    # Fixed EUR -> JPY rate used for the yen display.
    EUR_TO_JPY = 118

    def __init__(self):
        super().__init__()

        self.setWindowTitle("Used car price prediction in Germany")
        self.setGeometry(0, 0, 600, 500)

        self.setRadioButton()

    def setRadioButton(self):
        """Build all widgets and lay them out as four side-by-side columns."""
        brand_column1, brand_column2 = self._build_brand_columns()
        vehicle_column = self._build_vehicle_type_column()
        input_column = self._build_input_column()

        parentLayout = QHBoxLayout()
        for column in (brand_column1, brand_column2, vehicle_column, input_column):
            parentLayout.addLayout(column)
        self.setLayout(parentLayout)

    def _build_brand_columns(self):
        """Create the brand radio group plus the predict button; return two column layouts."""
        vbox1 = QVBoxLayout()
        vbox2 = QVBoxLayout()

        self.labelBrand = QLabel('Vehicle brand ?')
        vbox1.addWidget(self.labelBrand)

        self.radioGroup = QButtonGroup()
        for brand_id, brand in enumerate(self.BRANDS):
            button = QRadioButton(brand)
            # The button id is the categorical index passed to the model.
            self.radioGroup.addButton(button, brand_id)
            column = vbox1 if brand_id < self.BRAND_COLUMN_SPLIT else vbox2
            column.addWidget(button)
        self.radioGroup.button(0).setChecked(True)

        btn = QPushButton("Predict Price")
        btn.clicked.connect(self.buttoncallback)
        vbox1.addWidget(btn)

        return vbox1, vbox2

    def _build_vehicle_type_column(self):
        """Create the vehicle-type radio group and return its column layout."""
        vbox3 = QVBoxLayout()

        self.labelVehicleType = QLabel('Vehicle Type ?')
        vbox3.addWidget(self.labelVehicleType)

        self.radioGroup2 = QButtonGroup()
        for type_id, vehicle_type in enumerate(self.VEHICLE_TYPES):
            button = QRadioButton(vehicle_type)
            # The button id is the categorical index passed to the model.
            self.radioGroup2.addButton(button, type_id)
            vbox3.addWidget(button)
        self.radioGroup2.button(0).setChecked(True)

        # Empty label at the bottom of the column, presumably a spacer.
        vbox3.addWidget(QLabel(''))
        return vbox3

    def _build_input_column(self):
        """Create the numeric input fields and result labels; return their column layout."""
        vbox4 = QVBoxLayout()

        self.labelPower = QLabel('Power[PS]')
        self.labelKilo = QLabel('Mileage[km]')
        self.labelYearRegis = QLabel('Year of Registration')

        self.editPower = QLineEdit("0", self)
        self.editKilo = QLineEdit("0", self)
        self.editYearRegis = QLineEdit("0", self)

        self.labelTitle = QLabel('Prediction result[€]')
        self.labelPredict = QLabel('-', self)
        self.labelTitleYen = QLabel('Prediction result[¥]')
        self.labelPredictYen = QLabel('-', self)

        for widget in (
            self.labelPower, self.editPower,
            self.labelKilo, self.editKilo,
            self.labelYearRegis, self.editYearRegis,
            self.labelTitle, self.labelPredict,
            self.labelTitleYen, self.labelPredictYen,
        ):
            vbox4.addWidget(widget)
        return vbox4

    def buttoncallback(self):
        """Predict the price for the current form values and show it in EUR and JPY.

        Uses the module-level ``model``. If a numeric field cannot be parsed as
        a float, the yen label is reset to "-" and the euro label shows
        "Invalid input" instead of raising inside the Qt slot.
        """
        # Reset both results so a failed prediction never leaves a stale value behind.
        self.labelPredict.setText("-")
        self.labelPredictYen.setText("-")

        try:
            power = float(self.editPower.text())
            mileage = float(self.editKilo.text())
            year = float(self.editYearRegis.text())
        except ValueError:
            self.labelPredict.setText("Invalid input")
            return

        # Checked button ids are the categorical indices the embeddings expect.
        xcats = torch.tensor(
            [[self.radioGroup.checkedId(), self.radioGroup2.checkedId()]],
            dtype=torch.int64,
        )
        xconts = torch.tensor(
            [[power / self.POWER_SCALE,
              mileage / self.MILEAGE_SCALE,
              year / self.YEAR_SCALE]],
            dtype=torch.float,
        )

        with torch.no_grad():
            price_eur = model(xcats, xconts).item() * self.PRICE_SCALE

        self.labelPredict.setText("{:,.2f}".format(price_eur))
        self.labelPredictYen.setText("{:,.2f}".format(price_eur * self.EUR_TO_JPY))
        self.repaint()
            
            

# Reuse the running QApplication when this cell is re-executed in the same
# kernel; Qt permits only one QApplication instance per process.
myApp = QApplication.instance()
if myApp is None:
    myApp = QApplication(sys.argv)

window = Window()
window.show()

# Blocks until the window is closed. This notebook only runs under IPython
# (it uses % magics), where sys.exit() merely raises SystemExit and the kernel
# reports it as an error, so the exit code is kept instead of calling sys.exit.
exit_code = myApp.exec_()
