In [None]:
#使用fastaiV2版本
from fastai.vision.all import *

In [None]:
from zipfile import ZipFile

# Unpack train.zip (images/ + masks/) into the current working directory
with ZipFile('../input/tgs-salt-identification-challenge/train.zip') as zip_ref:
    zip_ref.extractall('')

# Gather the input images and their labels (in semantic segmentation the labels are called masks)
path = Path('')
fnames, lbl_names = (get_image_files(path/sub) for sub in ('images', 'masks'))

In [None]:
# Map an image path to its paired mask: same stem, but under ./masks/ and with a .png suffix
def get_mask(o):
    "Return the mask path './masks/<stem>.png' that belongs to image path `o`."
    return f'./masks/{o.stem}.png'

# Preview one image to confirm files are read correctly (disabled)
#img_fn = fnames[10]
#im = PILImage.create(img_fn)
#im.show(figsize=(5,5))

In [None]:
# Show the mask paired with the preview image `img_fn` (disabled)
#mask_fn = get_mask(img_fn)
#msk = PILMask.create(mask_fn)
#msk.show(figsize=(5,5), alpha=1)

In [None]:
# Collect every distinct mask pixel value; for semantic segmentation there should be
# exactly one value per class.
def n_codes(fnames, is_partial=True, n_sample=None):
    """Gather the codes from a list of mask files `fnames`.

    Returns a dict mapping class index -> raw pixel value. Pixel values are sorted
    ascending, so the mapping is deterministic (e.g. {0: 0, 1: 255}).

    is_partial: when True and `n_sample` is given, only a random sample of
        `n_sample` files is scanned (faster, but may miss rare values).
    n_sample: number of files to sample when `is_partial` is True; None (default)
        scans every file.

    The caller's `fnames` is never reordered.
    """
    import random

    fnames = list(fnames)  # private copy: never shuffle the caller's list in place
    if is_partial and n_sample is not None:
        fnames = random.sample(fnames, min(n_sample, len(fnames)))
    vals = set()
    for fname in fnames:
        vals.update(np.unique(np.array(PILMask.create(fname))).tolist())
    return dict(enumerate(sorted(vals)))
# Scan the label masks and display the class-index -> pixel-value mapping
p2c=n_codes(lbl_names)
p2c

In [None]:
# Load an image's mask and remap its raw pixel values to class indices 0, 1, ...
def get_mask2(fn,p2c=n_codes(lbl_names)):
    """Return the class-index mask paired with image path `fn`.

    fn: image path (anything with a `.stem`); its mask is read from
        './masks/<stem>.png'.
    p2c: dict class index -> raw mask pixel value, as produced by `n_codes`.
        NOTE: the default is evaluated once, when this cell runs, by scanning
        every file in `lbl_names`.

    Pixels whose value does not appear in `p2c` are left unchanged.
    """
    fn = './masks/'+str(fn.stem)+'.png'
    raw = np.array(PILMask.create(fn))
    msk = raw.copy()
    # Compare against the untouched `raw` array. Remapping `msk` in place would let a
    # later step re-map pixels written by an earlier one (e.g. {0: 255, 1: 0} would
    # turn every pixel into 1).
    for code, val in p2c.items():
        msk[raw == val] = code
    return PILMask.create(msk)

In [None]:
# DataBlock: images in, 2-class masks out (index 0 = 'Background', 1 = 'salt').
# The splitter is seeded so the train/validation split is reproducible across runs.
binary = DataBlock(blocks=(ImageBlock, MaskBlock( ['Background', 'salt'])),     # input images, target masks
                   get_items=get_image_files,                                   # x: every image file under the source folder
                   splitter=RandomSplitter(valid_pct=0.2, seed=42),             # random 80/20 split, fixed seed
                   get_y=get_mask2,                                             # y: class-index mask for each image
                   item_tfms=Resize(128),                                       # segmentation: only rescale items
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])          # normalize with ImageNet statistics

In [None]:
# Build the DataLoaders from the images folder and preview a batch
# (masks drawn with the 'Greens' colormap over the 0..1 class-index range)
dls = binary.dataloaders(path/'images')
dls.show_batch(cmap='Greens', vmin=0, vmax=1)

In [None]:
# U-Net with a ResNet-34 encoder; Dice is the usual metric for semantic segmentation.
# to_fp16() switches on mixed-precision training for speed.
learn = unet_learner(dls, resnet34, metrics=Dice)
learn = learn.to_fp16()

In [None]:
# Unfreeze the whole network and train for 10 epochs with fit_flat_cos. No learning
# rate is passed, so the library default is used. Higher Dice means a better result.
learn.unfreeze()
learn.fit_flat_cos(10)
# NOTE(review): depending on the fastai version, lr_find() may return a suggestion
# object rather than a plain float — confirm before re-enabling this line.
#learn.fit_flat_cos(10,learn.lr_find())
# Plot the recorded losses
learn.recorder.plot_loss()

In [None]:
# Unpack the test images and load the sample submission; the final CSV must follow
# its `id` order.
# Every competition input is read from one absolute location, so the cell does not
# depend on the current working directory.
COMPETITION_DIR = Path('/kaggle/input/tgs-salt-identification-challenge')
SUBMIT_FOLDER='test'
# ZipFile.extractall creates the target folder itself, so no os.mkdir is needed
with ZipFile(COMPETITION_DIR/'test.zip', 'r') as zip_ref:
    zip_ref.extractall(SUBMIT_FOLDER)
# sample_submission.csv is not inside test.zip; read it from the input directory
submit_mask = pd.read_csv(COMPETITION_DIR/'sample_submission.csv')

In [None]:
# Sanity check (disabled): every id in sample_submission.csv should have an image under test/images/
#for idx,name in (enumerate(submit_mask['id'].iloc[:])):
#    name =  'test/images/'+ str(name)+'.png'
#    if(not(os.path.exists(name))):
#        print (idx,name)
#print('done')

In [None]:
# Single-image inference on one test image; mainly used to cross-check the batched predictions
submit_index = 1
submit_image = 'test/images/'+ str(submit_mask['id'].iloc[submit_index])+'.png'
submit_predict = learn.predict(submit_image)
submit_np = np.array(submit_predict[0])
# NOTE(review): ndarray.resize reshapes the flat buffer in place (truncating or
# zero-padding); it does NOT rescale the image. If the model output is 128x128
# (item_tfms=Resize(128)), this keeps the first 101*101 values in row-major order,
# which scrambles pixel positions. Use an interpolating resize if a spatially
# correct 101x101 mask is needed.
submit_np.resize(101,101)
submit_np

In [None]:
# Build the test image paths in sample_submission order, so predictions line up with CSV rows
test_csv_names = submit_mask['id'].radd('test/images/').add('.png')
test_csv_names

In [None]:
# Batched inference over all test images, in the order of `test_csv_names`
test_dl = learn.dls.test_dl(test_csv_names)
preds = learn.get_preds(dl=test_dl)

In [None]:
# Cross-check: the batched get_preds result for `submit_index` must agree with the
# single-image learn.predict result. Both are compared in full at the model's output
# resolution; no ndarray.resize is applied, since that would only truncate the flat
# buffer.
single_mask = np.array(submit_predict[0]).astype(np.uint8)
# NOTE(review): assumes preds[0][i] is (n_classes, H, W) softmax probabilities with
# channel 0 = 'Background' (the MaskBlock codes order). With 2 classes,
# background < 0.5 is the same decision as argmax == 'salt'.
submit_np2 = np.array(preds[0][submit_index][0] < 0.5).astype(np.uint8)
print(single_mask.shape == submit_np2.shape and (single_mask == submit_np2).all())  # True when both agree

In [None]:
def rle_encode(im, order='F'):
    '''
    Run-length encode a binary mask for the submission file.

    im: numpy array, 1 - mask, 0 - background
    order: pixel traversal order passed to `ndarray.flatten`. The TGS Salt
        evaluation numbers pixels top-to-bottom, then left-to-right (column-major),
        hence the default 'F'. Pass 'C' for row-major numbering.
    Returns run length as string formatted "start length start length ...",
        with 1-indexed starts; '' for an empty mask.
    '''
    pixels = np.asarray(im).flatten(order=order)
    # Pad with zeros so runs that touch either end still produce a start and an end
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]  # convert end positions into run lengths
    return ' '.join(str(x) for x in runs)

In [None]:
from tqdm import tqdm
import torch.nn.functional as F

SUBMIT_SIZE = 101  # side length the submission masks are encoded at (matches the crop in the original loop)

def _to_submit_mask(probs, size=SUBMIT_SIZE):
    """Convert one image's class probabilities into a `size` x `size` uint8 salt mask.

    NOTE(review): assumes `probs` is a (n_classes, H, W) tensor whose channel 1 is
    'salt' (the MaskBlock codes order) — confirm against `dls.vocab`.
    The probabilities are interpolated back to the target resolution *before* the
    argmax. Truncating the flat array with ndarray.resize would scramble pixel positions.
    """
    resized = F.interpolate(probs[None].float(), size=(size, size),
                            mode='bilinear', align_corners=False)
    return resized[0].argmax(dim=0).cpu().numpy().astype(np.uint8)

assert len(preds[0]) == len(submit_mask), 'predictions and submission rows are misaligned'
rle_masks = [rle_encode(_to_submit_mask(preds[0][idx]))
             for idx in tqdm(range(len(submit_mask)))]
# Assign the whole column at once: chained indexing (df['col'][i] = v) may write to a
# temporary copy and is not supported under pandas copy-on-write.
submit_mask['rle_mask'] = rle_masks

In [None]:
# Write the submission CSV, named after the current time in the Asia/Shanghai zone (UTC+8)
import time
from datetime import datetime
import pytz
tz = pytz.timezone('Asia/Shanghai')  # UTC+8
csv_str = datetime.now(tz).strftime('%Y-%m-%d-%H-%M-%S') + '.csv'
submit_mask.to_csv(csv_str, index=False, header=True)

In [None]:
import os
from IPython.display import FileLink

# Move to Kaggle's output directory, list its contents, then render a download link to the CSV
os.chdir('/kaggle/working')
print(os.getcwd())
print(os.listdir("/kaggle/working"))
FileLink(csv_str)