:crowmageddon: probably broke everything :crowmageddon:
pjreddie committed Apr 10, 2017
1 parent 179ed8e commit 8d9ed0a
Showing 66 changed files with 1,077 additions and 1,010 deletions.
4 changes: 2 additions & 2 deletions cfg/coco.data
@@ -1,7 +1,7 @@
classes= 80
train = /home/pjreddie/data/coco/trainvalno5k.txt
valid = coco_testdev
#valid = data/coco_val_5k.list
#valid = coco_testdev
valid = data/coco_val_5k.list
names = data/coco.names
backup = /home/pjreddie/backup/
eval=coco
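This hunk switches validation back to the held-out 5k split: valid now points to data/coco_val_5k.list and the coco_testdev entry is commented out. For context, a minimal sketch of how a .data file like this is typically consumed, assuming darknet's read_data_cfg()/option_find_str() helpers from option_list.c; the default values in the calls are illustrative, not taken from this commit.

#include <stdio.h>
#include "list.h"
#include "option_list.h"

/* Sketch: read the key/value pairs of a .data file and report the paths
 * the trainer and validator will use. */
void print_data_cfg(char *datacfg)
{
    list *options = read_data_cfg(datacfg);
    char *train_list = option_find_str(options, "train", "data/train.list");
    char *valid_list = option_find_str(options, "valid", "data/coco_val_5k.list");
    char *names      = option_find_str(options, "names", "data/coco.names");
    int classes      = option_find_int(options, "classes", 80);
    printf("classes=%d\ntrain=%s\nvalid=%s\nnames=%s\n",
           classes, train_list, valid_list, names);
    free_list(options);
}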
4 changes: 2 additions & 2 deletions cfg/yolo.cfg
@@ -5,8 +5,8 @@ subdivisions=1
# Training
# batch=64
# subdivisions=8
height=416
width=416
height=608
width=608
channels=3
momentum=0.9
decay=0.0005
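The cfg default input moves from 416x416 to 608x608. YOLOv2 downsamples its input by a factor of 32, so width and height must stay multiples of 32 and the detection grid grows from 13x13 to 19x19 cells, i.e. finer localization for more compute. A small sketch of that relationship; the stride-32 assumption matches the standard yolo.cfg topology but is stated here, not read from this diff.

#include <stdio.h>

/* Sketch: how the cfg input size maps to the YOLOv2 output grid,
 * assuming the usual overall stride of 32. */
int main(void)
{
    int sizes[] = {416, 608};
    int i;
    for(i = 0; i < 2; ++i){
        int grid = sizes[i] / 32;              /* 416 -> 13, 608 -> 19 */
        printf("%dx%d input -> %dx%d grid (%d cells)\n",
               sizes[i], sizes[i], grid, grid, grid*grid);
    }
    return 0;
}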
16 changes: 8 additions & 8 deletions src/activation_layer.c
@@ -35,29 +35,29 @@ layer make_activation_layer(int batch, int inputs, ACTIVATION activation)
return l;
}

void forward_activation_layer(layer l, network_state state)
void forward_activation_layer(layer l, network net)
{
copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1);
copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);
activate_array(l.output, l.outputs*l.batch, l.activation);
}

void backward_activation_layer(layer l, network_state state)
void backward_activation_layer(layer l, network net)
{
gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
copy_cpu(l.outputs*l.batch, l.delta, 1, state.delta, 1);
copy_cpu(l.outputs*l.batch, l.delta, 1, net.delta, 1);
}

#ifdef GPU

void forward_activation_layer_gpu(layer l, network_state state)
void forward_activation_layer_gpu(layer l, network net)
{
copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
copy_ongpu(l.outputs*l.batch, net.input_gpu, 1, l.output_gpu, 1);
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
}

void backward_activation_layer_gpu(layer l, network_state state)
void backward_activation_layer_gpu(layer l, network net)
{
gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, net.delta_gpu, 1);
}
#endif
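The recurring signature change in this commit: the per-call state that used to travel in a separate network_state struct (input, delta, the train flag, and their GPU counterparts) now lives on the network struct itself, so every forward_*/backward_* entry point takes the network directly and reads net.input/net.delta on the CPU path and net.input_gpu/net.delta_gpu on the GPU path. A minimal sketch of the resulting calling convention; field names (net.n, net.layers, l.forward, l.output) follow darknet's structs, but this is a simplified illustration, not the repository's forward_network().

#include "network.h"

/* Sketch: the driver loop under the new convention. The network struct
 * carries the running input pointer, and each layer's forward hook
 * receives the whole network rather than a network_state. */
void forward_network_sketch(network net, float *x)
{
    int i;
    net.input = x;                 /* what the first layer consumes */
    net.train = 0;                 /* inference mode */
    for(i = 0; i < net.n; ++i){
        layer l = net.layers[i];
        l.forward(l, net);         /* e.g. forward_activation_layer(l, net) */
        net.input = l.output;      /* next layer reads this layer's output */
    }
}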
8 changes: 4 additions & 4 deletions src/activation_layer.h
@@ -7,12 +7,12 @@

layer make_activation_layer(int batch, int inputs, ACTIVATION activation);

void forward_activation_layer(layer l, network_state state);
void backward_activation_layer(layer l, network_state state);
void forward_activation_layer(layer l, network net);
void backward_activation_layer(layer l, network net);

#ifdef GPU
void forward_activation_layer_gpu(layer l, network_state state);
void backward_activation_layer_gpu(layer l, network_state state);
void forward_activation_layer_gpu(layer l, network net);
void backward_activation_layer_gpu(layer l, network net);
#endif

#endif
8 changes: 4 additions & 4 deletions src/avgpool_layer.c
@@ -37,7 +37,7 @@ void resize_avgpool_layer(avgpool_layer *l, int w, int h)
l->inputs = h*w*l->c;
}

void forward_avgpool_layer(const avgpool_layer l, network_state state)
void forward_avgpool_layer(const avgpool_layer l, network net)
{
int b,i,k;

@@ -47,14 +47,14 @@ void forward_avgpool_layer(const avgpool_layer l, network_state state)
l.output[out_index] = 0;
for(i = 0; i < l.h*l.w; ++i){
int in_index = i + l.h*l.w*(k + b*l.c);
l.output[out_index] += state.input[in_index];
l.output[out_index] += net.input[in_index];
}
l.output[out_index] /= l.h*l.w;
}
}
}

void backward_avgpool_layer(const avgpool_layer l, network_state state)
void backward_avgpool_layer(const avgpool_layer l, network net)
{
int b,i,k;

@@ -63,7 +63,7 @@ void backward_avgpool_layer(const avgpool_layer l, network_state state)
int out_index = k + b*l.c;
for(i = 0; i < l.h*l.w; ++i){
int in_index = i + l.h*l.w*(k + b*l.c);
state.delta[in_index] += l.delta[out_index] / (l.h*l.w);
net.delta[in_index] += l.delta[out_index] / (l.h*l.w);
}
}
}
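forward_avgpool_layer is a global average pool: each (batch, channel) pair collapses its h*w spatial positions into one mean, and backward_avgpool_layer spreads l.delta[out_index] back uniformly, dividing by h*w. A standalone toy example of the forward reduction, using the same flat indexing as the loops above (plain C with made-up values, not darknet types):

#include <stdio.h>

int main(void)
{
    enum { B = 1, C = 2, H = 2, W = 2 };
    float input[B*C*H*W] = {1, 2, 3, 4,   10, 20, 30, 40};
    float output[B*C];
    int b, k, i;

    for(b = 0; b < B; ++b){
        for(k = 0; k < C; ++k){
            int out_index = k + b*C;
            output[out_index] = 0;
            for(i = 0; i < H*W; ++i){
                output[out_index] += input[i + H*W*(k + b*C)];
            }
            output[out_index] /= H*W;          /* per-channel spatial mean */
        }
    }
    printf("%.1f %.1f\n", output[0], output[1]);   /* prints 2.5 25.0 */
    return 0;
}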
8 changes: 4 additions & 4 deletions src/avgpool_layer.h
@@ -11,12 +11,12 @@ typedef layer avgpool_layer;
image get_avgpool_image(avgpool_layer l);
avgpool_layer make_avgpool_layer(int batch, int w, int h, int c);
void resize_avgpool_layer(avgpool_layer *l, int w, int h);
void forward_avgpool_layer(const avgpool_layer l, network_state state);
void backward_avgpool_layer(const avgpool_layer l, network_state state);
void forward_avgpool_layer(const avgpool_layer l, network net);
void backward_avgpool_layer(const avgpool_layer l, network net);

#ifdef GPU
void forward_avgpool_layer_gpu(avgpool_layer l, network_state state);
void backward_avgpool_layer_gpu(avgpool_layer l, network_state state);
void forward_avgpool_layer_gpu(avgpool_layer l, network net);
void backward_avgpool_layer_gpu(avgpool_layer l, network net);
#endif

#endif
8 changes: 4 additions & 4 deletions src/avgpool_layer_kernels.cu
@@ -43,19 +43,19 @@ __global__ void backward_avgpool_layer_kernel(int n, int w, int h, int c, float
}
}

extern "C" void forward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
extern "C" void forward_avgpool_layer_gpu(avgpool_layer layer, network net)
{
size_t n = layer.c*layer.batch;

forward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.input, layer.output_gpu);
forward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, net.input_gpu, layer.output_gpu);
check_error(cudaPeekAtLastError());
}

extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network net)
{
size_t n = layer.c*layer.batch;

backward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.delta, layer.delta_gpu);
backward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, net.delta_gpu, layer.delta_gpu);
check_error(cudaPeekAtLastError());
}

64 changes: 36 additions & 28 deletions src/batchnorm_layer.c
@@ -132,14 +132,15 @@ void resize_batchnorm_layer(layer *layer, int w, int h)
fprintf(stderr, "Not implemented\n");
}

void forward_batchnorm_layer(layer l, network_state state)
void forward_batchnorm_layer(layer l, network net)
{
if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1);
if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);
if(l.type == CONNECTED){
l.out_c = l.outputs;
l.out_h = l.out_w = 1;
}
if(state.train){
copy_cpu(l.outputs*l.batch, l.output, 1, l.x, 1);
if(net.train){
mean_cpu(l.output, l.batch, l.out_c, l.out_h*l.out_w, l.mean);
variance_cpu(l.output, l.mean, l.batch, l.out_c, l.out_h*l.out_w, l.variance);

@@ -148,7 +149,6 @@ void forward_batchnorm_layer(layer l, network_state state)
scal_cpu(l.out_c, .99, l.rolling_variance, 1);
axpy_cpu(l.out_c, .01, l.variance, 1, l.rolling_variance, 1);

copy_cpu(l.outputs*l.batch, l.output, 1, l.x, 1);
normalize_cpu(l.output, l.mean, l.variance, l.batch, l.out_c, l.out_h*l.out_w);
copy_cpu(l.outputs*l.batch, l.output, 1, l.x_norm, 1);
} else {
@@ -158,8 +158,12 @@ void forward_batchnorm_layer(layer l, network_state state)
add_bias(l.output, l.biases, l.batch, l.out_c, l.out_h*l.out_w);
}

void backward_batchnorm_layer(const layer l, network_state state)
void backward_batchnorm_layer(layer l, network net)
{
if(!net.train){
l.mean = l.rolling_mean;
l.variance = l.rolling_variance;
}
backward_bias(l.bias_updates, l.delta, l.batch, l.out_c, l.out_w*l.out_h);
backward_scale_cpu(l.x_norm, l.delta, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates);

@@ -168,7 +172,7 @@ void backward_batchnorm_layer(const layer l, network_state state)
mean_delta_cpu(l.delta, l.variance, l.batch, l.out_c, l.out_w*l.out_h, l.mean_delta);
variance_delta_cpu(l.x, l.delta, l.mean, l.variance, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta);
normalize_delta_cpu(l.x, l.mean, l.variance, l.mean_delta, l.variance_delta, l.batch, l.out_c, l.out_w*l.out_h, l.delta);
if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, l.delta, 1, state.delta, 1);
if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, l.delta, 1, net.delta, 1);
}

#ifdef GPU
@@ -186,35 +190,35 @@ void push_batchnorm_layer(layer l)
cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
}

void forward_batchnorm_layer_gpu(layer l, network_state state)
void forward_batchnorm_layer_gpu(layer l, network net)
{
if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, net.input_gpu, 1, l.output_gpu, 1);
if(l.type == CONNECTED){
l.out_c = l.outputs;
l.out_h = l.out_w = 1;
}
if (state.train) {
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
if (net.train) {
#ifdef CUDNN
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
float one = 1;
float zero = 0;
cudnnBatchNormalizationForwardTraining(cudnn_handle(),
CUDNN_BATCHNORM_SPATIAL,
&one,
&zero,
l.dstTensorDesc,
l.x_gpu,
l.dstTensorDesc,
l.output_gpu,
l.normTensorDesc,
l.scales_gpu,
l.biases_gpu,
.01,
l.rolling_mean_gpu,
l.rolling_variance_gpu,
.00001,
l.mean_gpu,
l.variance_gpu);
#else
fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
@@ -239,8 +243,12 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)

}

void backward_batchnorm_layer_gpu(const layer l, network_state state)
void backward_batchnorm_layer_gpu(layer l, network net)
{
if(!net.train){
l.mean_gpu = l.rolling_mean_gpu;
l.variance_gpu = l.rolling_variance_gpu;
}
#ifdef CUDNN
float one = 1;
float zero = 0;
@@ -274,6 +282,6 @@ void backward_batchnorm_layer_gpu(const layer l, network_state state)
fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta_gpu);
normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.delta_gpu);
#endif
if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, net.delta_gpu, 1);
}
#endif
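Two things change in the batch norm code: the snapshot of the pre-normalization output into l.x (or l.x_gpu) now happens before the train/test branch rather than inside it, and the backward pass substitutes the rolling mean/variance when net.train is false, so test-time gradients are taken against the same statistics the forward pass normalized with. For reference, inference-mode batch norm per channel reduces to the sketch below; a standalone illustration in plain C (epsilon .00001 to match the CUDNN call above), not the repository's normalize_cpu()/scale_bias()/add_bias() path.

#include <math.h>
#include <stdio.h>

/* Sketch: normalize n values of one channel with the stored rolling
 * statistics, then apply the learned scale (gamma) and bias (beta). */
void bn_inference(float *x, int n, float rolling_mean, float rolling_variance,
                  float scale, float bias)
{
    int i;
    for(i = 0; i < n; ++i){
        float xhat = (x[i] - rolling_mean) / sqrtf(rolling_variance + .00001f);
        x[i] = scale*xhat + bias;
    }
}

int main(void)
{
    float x[4] = {1, 2, 3, 4};
    int i;
    bn_inference(x, 4, 2.5f, 1.25f, 1.0f, 0.0f);
    for(i = 0; i < 4; ++i) printf("%.3f ", x[i]);
    printf("\n");
    return 0;
}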
8 changes: 4 additions & 4 deletions src/batchnorm_layer.h
@@ -6,12 +6,12 @@
#include "network.h"

layer make_batchnorm_layer(int batch, int w, int h, int c);
void forward_batchnorm_layer(layer l, network_state state);
void backward_batchnorm_layer(layer l, network_state state);
void forward_batchnorm_layer(layer l, network net);
void backward_batchnorm_layer(layer l, network net);

#ifdef GPU
void forward_batchnorm_layer_gpu(layer l, network_state state);
void backward_batchnorm_layer_gpu(layer l, network_state state);
void forward_batchnorm_layer_gpu(layer l, network net);
void backward_batchnorm_layer_gpu(layer l, network net);
void pull_batchnorm_layer(layer l);
void push_batchnorm_layer(layer l);
#endif
2 changes: 1 addition & 1 deletion src/blas_kernels.cu
@@ -145,7 +145,7 @@ __global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (index >= N) return;

x[index] = x[index] - (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps));
x[index] = x[index] + (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps));
//if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
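The functional change in this kernel is the sign of the Adam step. The textbook update is x <- x - rate * m_hat / (sqrt(v_hat) + eps); whether code subtracts or adds depends on whether m was accumulated from the gradient or from its negative, and darknet's plain SGD path adds weight_updates to the weights, which points to the latter convention. For comparison, a standalone sketch of one textbook Adam step on a scalar; the explicit bias-correction form is the usual one, not lifted from this commit.

#include <math.h>

/* Sketch: one Adam step for a single parameter x with gradient g.
 * m and v are the running first/second moment estimates, t the step count.
 * The kernel above folds the bias correction into the rate and adds the
 * step, consistent with its m being accumulated from the negative gradient. */
float adam_step(float x, float g, float *m, float *v,
                float rate, float B1, float B2, float eps, int t)
{
    *m = B1*(*m) + (1 - B1)*g;             /* first moment   */
    *v = B2*(*v) + (1 - B2)*g*g;           /* second moment  */
    float mhat = *m / (1 - powf(B1, t));   /* bias-corrected */
    float vhat = *v / (1 - powf(B2, t));
    return x - rate*mhat/(sqrtf(vhat) + eps);
}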
}

2 changes: 1 addition & 1 deletion src/classifier.c
@@ -123,7 +123,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
save_weights(net, buff);
}
if(get_current_batch(net)%100 == 0){
if(get_current_batch(net)%1000 == 0){
char buff[256];
sprintf(buff, "%s/%s.backup",backup_directory,base);
save_weights(net, buff);

2 comments on commit 8d9ed0a

@sivagnanamn

YOLO V1 training is broken: "Detection Avg IOU", "Pos Cat", "All Cat", etc. are always 0.000000 after this commit.

Loaded: 4.542056 seconds
Detection Avg IOU: 0.000000, Pos Cat: 0.000000, All Cat: 0.000000, Pos Obj: 0.000000, Any Obj: 0.000000, count: 48
Detection Avg IOU: 0.000000, Pos Cat: 0.000000, All Cat: 0.000000, Pos Obj: 0.000000, Any Obj: 0.000000, count: 52
1: 44.408333, 44.408333 avg, 0.000500 rate, 7.206092 seconds, 64 images
Loaded: 0.000018 seconds
Detection Avg IOU: 0.000000, Pos Cat: 0.000000, All Cat: 0.000000, Pos Obj: 0.000000, Any Obj: 0.000000, count: 50
Detection Avg IOU: 0.000000, Pos Cat: 0.000000, All Cat: 0.000000, Pos Obj: 0.000000, Any Obj: 0.000000, count: 42
2: 44.872047, 44.454704 avg, 0.000500 rate, 8.551364 seconds, 128 images
Loaded: 0.000018 seconds

@Alexey-Kamenev

I've created a pull request for this issue.
